Remove pubkey flag from init

This commit is contained in:
katexochen 2022-03-28 08:58:56 +02:00 committed by Paul Meyer
parent 7275f318f8
commit 5cf8f83ed8
2 changed files with 69 additions and 106 deletions

View File

@ -35,8 +35,8 @@ func newInitCmd() *cobra.Command {
RunE: runInitialize, RunE: runInitialize,
} }
cmd.Flags().String("privatekey", "", "path to your private key.") cmd.Flags().String("privatekey", "", "path to your private key.")
cmd.Flags().String("publickey", "", "path to your public key.")
cmd.Flags().String("master-secret", "", "path to base64 encoded master secret.") cmd.Flags().String("master-secret", "", "path to base64 encoded master secret.")
cmd.Flags().Bool("wg-autoconfig", false, "enable automatic configuration of WireGuard interface")
cmd.Flags().Bool("autoscale", false, "enable kubernetes cluster-autoscaler") cmd.Flags().Bool("autoscale", false, "enable kubernetes cluster-autoscaler")
return cmd return cmd
} }
@ -228,11 +228,11 @@ func evalFlagArgs(cmd *cobra.Command, fileHandler file.Handler, config *config.C
if err != nil { if err != nil {
return flagArgs{}, err return flagArgs{}, err
} }
userPublicKeyPath, err := cmd.Flags().GetString("publickey") userPrivKey, userPubKey, err := readOrGenerateVPNKey(fileHandler, userPrivKeyPath)
if err != nil { if err != nil {
return flagArgs{}, err return flagArgs{}, err
} }
userPrivKey, userPubKey, err := readVpnKey(fileHandler, userPrivKeyPath, userPublicKeyPath) autoconfigureWG, err := cmd.Flags().GetBool("wg-autoconfig")
if err != nil { if err != nil {
return flagArgs{}, err return flagArgs{}, err
} }
@ -252,7 +252,7 @@ func evalFlagArgs(cmd *cobra.Command, fileHandler file.Handler, config *config.C
return flagArgs{ return flagArgs{
userPrivKey: userPrivKey, userPrivKey: userPrivKey,
userPubKey: userPubKey, userPubKey: userPubKey,
autoconfigureWG: userPrivKeyPath != "", autoconfigureWG: autoconfigureWG,
autoscale: autoscale, autoscale: autoscale,
masterSecret: masterSecret, masterSecret: masterSecret,
}, nil }, nil
@ -267,28 +267,27 @@ type flagArgs struct {
autoscale bool autoscale bool
} }
func readVpnKey(fileHandler file.Handler, privKeyPath, publicKeyPath string) (privKey, pubKey []byte, err error) { func readOrGenerateVPNKey(fileHandler file.Handler, privKeyPath string) (privKey, pubKey []byte, err error) {
if privKeyPath != "" { var privKeyParsed wgtypes.Key
if privKeyPath == "" {
privKeyParsed, err = wgtypes.GeneratePrivateKey()
if err != nil {
return nil, nil, err
}
privKey = []byte(privKeyParsed.String())
} else {
privKey, err = fileHandler.Read(privKeyPath) privKey, err = fileHandler.Read(privKeyPath)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
privKeyParsed, err := wgtypes.ParseKey(string(privKey)) privKeyParsed, err = wgtypes.ParseKey(string(privKey))
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
pubKey = []byte(privKeyParsed.PublicKey().String())
} else if publicKeyPath != "" {
pubKey, err = fileHandler.Read(publicKeyPath)
if err != nil {
return nil, nil, err
}
if err := checkBase64WGKey(pubKey); err != nil {
return nil, nil, fmt.Errorf("wireguard public key is invalid: %w", err)
}
} else {
return nil, nil, errors.New("neither path to public nor private key provided")
} }
pubKey = []byte(privKeyParsed.PublicKey().String())
return privKey, pubKey, nil return privKey, pubKey, nil
} }
@ -308,17 +307,6 @@ func ipsToEndpoints(ips []string, port string) []string {
return endpoints return endpoints
} }
func checkBase64WGKey(b []byte) error {
keyStr, err := base64.StdEncoding.DecodeString(string(b))
if err != nil {
return errors.New("key can't be decoded")
}
if length := len(keyStr); length != wireguardKeyLength {
return fmt.Errorf("key has invalid length %d", length)
}
return nil
}
// readOrGeneratedMasterSecret reads a base64 encoded master secret from file or generates a new 32 byte secret. // readOrGeneratedMasterSecret reads a base64 encoded master secret from file or generates a new 32 byte secret.
func readOrGeneratedMasterSecret(w io.Writer, fileHandler file.Handler, filename string, config *config.Config) ([]byte, error) { func readOrGeneratedMasterSecret(w io.Writer, fileHandler file.Handler, filename string, config *config.Config) ([]byte, error) {
if filename != "" { if filename != "" {

View File

@ -96,7 +96,7 @@ func TestInitialize(t *testing.T) {
{ {
kubeconfig: "kubeconfig", kubeconfig: "kubeconfig",
clientVpnIp: "vpnIp", clientVpnIp: "vpnIp",
coordinatorVpnKey: "coordKey", coordinatorVpnKey: testKey,
ownerID: "ownerID", ownerID: "ownerID",
clusterID: "clusterID", clusterID: "clusterID",
}, },
@ -109,7 +109,7 @@ func TestInitialize(t *testing.T) {
client protoClient client protoClient
serviceAccountCreator stubServiceAccountCreator serviceAccountCreator stubServiceAccountCreator
waiter statusWaiter waiter statusWaiter
pubKey string privKey string
errExpected bool errExpected bool
}{ }{
"initialize some ec2 instances": { "initialize some ec2 instances": {
@ -117,30 +117,30 @@ func TestInitialize(t *testing.T) {
client: &fakeProtoClient{ client: &fakeProtoClient{
respClient: &fakeActivationRespClient{responses: testActivationResps}, respClient: &fakeActivationRespClient{responses: testActivationResps},
}, },
waiter: stubStatusWaiter{}, waiter: stubStatusWaiter{},
pubKey: testKey, privKey: testKey,
}, },
"initialize some gcp instances": { "initialize some gcp instances": {
existingState: testGcpState, existingState: testGcpState,
client: &fakeProtoClient{ client: &fakeProtoClient{
respClient: &fakeActivationRespClient{responses: testActivationResps}, respClient: &fakeActivationRespClient{responses: testActivationResps},
}, },
waiter: stubStatusWaiter{}, waiter: stubStatusWaiter{},
pubKey: testKey, privKey: testKey,
}, },
"initialize some azure instances": { "initialize some azure instances": {
existingState: testAzureState, existingState: testAzureState,
client: &fakeProtoClient{ client: &fakeProtoClient{
respClient: &fakeActivationRespClient{responses: testActivationResps}, respClient: &fakeActivationRespClient{responses: testActivationResps},
}, },
waiter: stubStatusWaiter{}, waiter: stubStatusWaiter{},
pubKey: testKey, privKey: testKey,
}, },
"no state exists": { "no state exists": {
existingState: state.ConstellationState{}, existingState: state.ConstellationState{},
client: &stubProtoClient{}, client: &stubProtoClient{},
waiter: stubStatusWaiter{}, waiter: stubStatusWaiter{},
pubKey: testKey, privKey: testKey,
errExpected: true, errExpected: true,
}, },
"no instances to pick one": { "no instances to pick one": {
@ -150,7 +150,7 @@ func TestInitialize(t *testing.T) {
}, },
client: &stubProtoClient{}, client: &stubProtoClient{},
waiter: stubStatusWaiter{}, waiter: stubStatusWaiter{},
pubKey: testKey, privKey: testKey,
errExpected: true, errExpected: true,
}, },
"only one instance": { "only one instance": {
@ -160,91 +160,91 @@ func TestInitialize(t *testing.T) {
}, },
client: &stubProtoClient{}, client: &stubProtoClient{},
waiter: stubStatusWaiter{}, waiter: stubStatusWaiter{},
pubKey: testKey, privKey: testKey,
errExpected: true, errExpected: true,
}, },
"public key to short": { "public key to short": {
existingState: testEc2State, existingState: testEc2State,
client: &stubProtoClient{}, client: &stubProtoClient{},
waiter: stubStatusWaiter{}, waiter: stubStatusWaiter{},
pubKey: base64.StdEncoding.EncodeToString([]byte("tooShortKey")), privKey: base64.StdEncoding.EncodeToString([]byte("tooShortKey")),
errExpected: true, errExpected: true,
}, },
"public key to long": { "public key to long": {
existingState: testEc2State, existingState: testEc2State,
client: &stubProtoClient{}, client: &stubProtoClient{},
waiter: stubStatusWaiter{}, waiter: stubStatusWaiter{},
pubKey: base64.StdEncoding.EncodeToString([]byte("thisWireguardKeyIsToLongAndHasTooManyBytes")), privKey: base64.StdEncoding.EncodeToString([]byte("thisWireguardKeyIsToLongAndHasTooManyBytes")),
errExpected: true, errExpected: true,
}, },
"public key not base64": { "public key not base64": {
existingState: testEc2State, existingState: testEc2State,
client: &stubProtoClient{}, client: &stubProtoClient{},
waiter: stubStatusWaiter{}, waiter: stubStatusWaiter{},
pubKey: "this is not base64 encoded", privKey: "this is not base64 encoded",
errExpected: true, errExpected: true,
}, },
"fail Connect": { "fail Connect": {
existingState: testEc2State, existingState: testEc2State,
client: &stubProtoClient{connectErr: someErr}, client: &stubProtoClient{connectErr: someErr},
waiter: stubStatusWaiter{}, waiter: stubStatusWaiter{},
pubKey: testKey, privKey: testKey,
errExpected: true, errExpected: true,
}, },
"fail Activate": { "fail Activate": {
existingState: testEc2State, existingState: testEc2State,
client: &stubProtoClient{activateErr: someErr}, client: &stubProtoClient{activateErr: someErr},
waiter: stubStatusWaiter{}, waiter: stubStatusWaiter{},
pubKey: testKey, privKey: testKey,
errExpected: true, errExpected: true,
}, },
"fail respClient WriteLogStream": { "fail respClient WriteLogStream": {
existingState: testEc2State, existingState: testEc2State,
client: &stubProtoClient{respClient: &stubActivationRespClient{writeLogStreamErr: someErr}}, client: &stubProtoClient{respClient: &stubActivationRespClient{writeLogStreamErr: someErr}},
waiter: stubStatusWaiter{}, waiter: stubStatusWaiter{},
pubKey: testKey, privKey: testKey,
errExpected: true, errExpected: true,
}, },
"fail respClient getKubeconfig": { "fail respClient getKubeconfig": {
existingState: testEc2State, existingState: testEc2State,
client: &stubProtoClient{respClient: &stubActivationRespClient{getKubeconfigErr: someErr}}, client: &stubProtoClient{respClient: &stubActivationRespClient{getKubeconfigErr: someErr}},
waiter: stubStatusWaiter{}, waiter: stubStatusWaiter{},
pubKey: testKey, privKey: testKey,
errExpected: true, errExpected: true,
}, },
"fail respClient getCoordinatorVpnKey": { "fail respClient getCoordinatorVpnKey": {
existingState: testEc2State, existingState: testEc2State,
client: &stubProtoClient{respClient: &stubActivationRespClient{getCoordinatorVpnKeyErr: someErr}}, client: &stubProtoClient{respClient: &stubActivationRespClient{getCoordinatorVpnKeyErr: someErr}},
waiter: stubStatusWaiter{}, waiter: stubStatusWaiter{},
pubKey: testKey, privKey: testKey,
errExpected: true, errExpected: true,
}, },
"fail respClient getClientVpnIp": { "fail respClient getClientVpnIp": {
existingState: testEc2State, existingState: testEc2State,
client: &stubProtoClient{respClient: &stubActivationRespClient{getClientVpnIpErr: someErr}}, client: &stubProtoClient{respClient: &stubActivationRespClient{getClientVpnIpErr: someErr}},
waiter: stubStatusWaiter{}, waiter: stubStatusWaiter{},
pubKey: testKey, privKey: testKey,
errExpected: true, errExpected: true,
}, },
"fail respClient getOwnerID": { "fail respClient getOwnerID": {
existingState: testEc2State, existingState: testEc2State,
client: &stubProtoClient{respClient: &stubActivationRespClient{getOwnerIDErr: someErr}}, client: &stubProtoClient{respClient: &stubActivationRespClient{getOwnerIDErr: someErr}},
waiter: stubStatusWaiter{}, waiter: stubStatusWaiter{},
pubKey: testKey, privKey: testKey,
errExpected: true, errExpected: true,
}, },
"fail respClient getClusterID": { "fail respClient getClusterID": {
existingState: testEc2State, existingState: testEc2State,
client: &stubProtoClient{respClient: &stubActivationRespClient{getClusterIDErr: someErr}}, client: &stubProtoClient{respClient: &stubActivationRespClient{getClusterIDErr: someErr}},
waiter: stubStatusWaiter{}, waiter: stubStatusWaiter{},
pubKey: testKey, privKey: testKey,
errExpected: true, errExpected: true,
}, },
"fail to wait for required status": { "fail to wait for required status": {
existingState: testGcpState, existingState: testGcpState,
client: &stubProtoClient{}, client: &stubProtoClient{},
waiter: stubStatusWaiter{waitForAllErr: someErr}, waiter: stubStatusWaiter{waitForAllErr: someErr},
pubKey: testKey, privKey: testKey,
errExpected: true, errExpected: true,
}, },
"fail to create service account": { "fail to create service account": {
@ -254,7 +254,7 @@ func TestInitialize(t *testing.T) {
createErr: someErr, createErr: someErr,
}, },
waiter: stubStatusWaiter{}, waiter: stubStatusWaiter{},
pubKey: testKey, privKey: testKey,
errExpected: true, errExpected: true,
}, },
} }
@ -274,8 +274,8 @@ func TestInitialize(t *testing.T) {
require.NoError(fileHandler.WriteJSON(*config.StatePath, tc.existingState, false)) require.NoError(fileHandler.WriteJSON(*config.StatePath, tc.existingState, false))
// Write key file to filesystem and set path in flag. // Write key file to filesystem and set path in flag.
require.NoError(afero.Afero{Fs: fs}.WriteFile("pubKPath", []byte(tc.pubKey), 0o600)) require.NoError(afero.Afero{Fs: fs}.WriteFile("privK", []byte(tc.privKey), 0o600))
require.NoError(cmd.Flags().Set("publickey", "pubKPath")) require.NoError(cmd.Flags().Set("privatekey", "privK"))
ctx := context.Background() ctx := context.Background()
ctx, cancel := context.WithTimeout(ctx, 4*time.Second) ctx, cancel := context.WithTimeout(ctx, 4*time.Second)
defer cancel() defer cancel()
@ -287,7 +287,6 @@ func TestInitialize(t *testing.T) {
} else { } else {
require.NoError(err) require.NoError(err)
assert.Contains(out.String(), "vpnIp") assert.Contains(out.String(), "vpnIp")
assert.Contains(out.String(), "coordKey")
assert.Contains(out.String(), "ownerID") assert.Contains(out.String(), "ownerID")
assert.Contains(out.String(), "clusterID") assert.Contains(out.String(), "clusterID")
} }
@ -345,19 +344,6 @@ func TestIpsToEndpoints(t *testing.T) {
assert.Equal([]string{"192.0.2.1:8080", "192.0.2.2:8080", "192.0.2.3:8080"}, endpoints) assert.Equal([]string{"192.0.2.1:8080", "192.0.2.2:8080", "192.0.2.3:8080"}, endpoints)
} }
func TestCheckBase64WGKEy(t *testing.T) {
assert := assert.New(t)
key := []byte(base64.StdEncoding.EncodeToString([]byte("32bytesWireGuardKeyForTheTesting")))
assert.NoError(checkBase64WGKey(key))
key = []byte(base64.StdEncoding.EncodeToString([]byte("shortKey")))
assert.Error(checkBase64WGKey(key))
key = []byte(base64.StdEncoding.EncodeToString([]byte("looooooooooongKeyWithMoreThan32Bytes")))
assert.Error(checkBase64WGKey(key))
key = []byte("noBase 64")
assert.Error(checkBase64WGKey(key))
}
func TestInitCompletion(t *testing.T) { func TestInitCompletion(t *testing.T) {
testCases := map[string]struct { testCases := map[string]struct {
args []string args []string
@ -397,35 +383,24 @@ func TestInitCompletion(t *testing.T) {
} }
} }
func TestReadVpnKey(t *testing.T) { func TestReadOrGenerateVPNKey(t *testing.T) {
assert := assert.New(t) assert := assert.New(t)
require := require.New(t) require := require.New(t)
testKeyA := []byte(base64.StdEncoding.EncodeToString([]byte("32bytesWireGuardKeyForTheTesting"))) testKey := []byte(base64.StdEncoding.EncodeToString([]byte("32bytesWireGuardKeyForTheTesting")))
testKeyB := []byte(base64.StdEncoding.EncodeToString([]byte("anotherWireGuardKeyForTheTesting")))
fileHandler := file.NewHandler(afero.NewMemMapFs()) fileHandler := file.NewHandler(afero.NewMemMapFs())
require.NoError(fileHandler.Write("testKeyA", testKeyA, false)) require.NoError(fileHandler.Write("testKey", testKey, false))
require.NoError(fileHandler.Write("testKeyB", testKeyB, false))
// provide privK privK, pubK, err := readOrGenerateVPNKey(fileHandler, "testKey")
privK, _, err := readVpnKey(fileHandler, "testKeyA", "")
assert.NoError(err) assert.NoError(err)
assert.Equal(testKeyA, privK) assert.Equal(testKey, privK)
assert.NotEmpty(pubK)
// provide pubK
_, pubK, err := readVpnKey(fileHandler, "", "testKeyA")
assert.NoError(err)
assert.Equal(testKeyA, pubK)
// provide both, privK should be used, pubK ignored
privK, pubK, err = readVpnKey(fileHandler, "testKeyA", "testKeyB")
assert.NoError(err)
assert.Equal(testKeyA, privK)
assert.NotEqual(testKeyB, pubK)
// no path provided // no path provided
_, _, err = readVpnKey(fileHandler, "", "") privK, pubK, err = readOrGenerateVPNKey(fileHandler, "")
assert.Error(err) assert.NoError(err)
assert.NotEmpty(privK)
assert.NotEmpty(pubK)
} }
func TestReadOrGeneratedMasterSecret(t *testing.T) { func TestReadOrGeneratedMasterSecret(t *testing.T) {
@ -583,7 +558,7 @@ func TestAutoscaleFlag(t *testing.T) {
{ {
kubeconfig: "kubeconfig", kubeconfig: "kubeconfig",
clientVpnIp: "vpnIp", clientVpnIp: "vpnIp",
coordinatorVpnKey: "coordKey", coordinatorVpnKey: testKey,
ownerID: "ownerID", ownerID: "ownerID",
clusterID: "clusterID", clusterID: "clusterID",
}, },
@ -596,7 +571,7 @@ func TestAutoscaleFlag(t *testing.T) {
client *stubProtoClient client *stubProtoClient
serviceAccountCreator stubServiceAccountCreator serviceAccountCreator stubServiceAccountCreator
waiter statusWaiter waiter statusWaiter
pubKey string privKey string
}{ }{
"initialize some ec2 instances without autoscale flag": { "initialize some ec2 instances without autoscale flag": {
autoscaleFlag: false, autoscaleFlag: false,
@ -604,8 +579,8 @@ func TestAutoscaleFlag(t *testing.T) {
client: &stubProtoClient{ client: &stubProtoClient{
respClient: &fakeActivationRespClient{responses: testActivationResps}, respClient: &fakeActivationRespClient{responses: testActivationResps},
}, },
waiter: stubStatusWaiter{}, waiter: stubStatusWaiter{},
pubKey: testKey, privKey: testKey,
}, },
"initialize some gcp instances without autoscale flag": { "initialize some gcp instances without autoscale flag": {
autoscaleFlag: false, autoscaleFlag: false,
@ -613,8 +588,8 @@ func TestAutoscaleFlag(t *testing.T) {
client: &stubProtoClient{ client: &stubProtoClient{
respClient: &fakeActivationRespClient{responses: testActivationResps}, respClient: &fakeActivationRespClient{responses: testActivationResps},
}, },
waiter: stubStatusWaiter{}, waiter: stubStatusWaiter{},
pubKey: testKey, privKey: testKey,
}, },
"initialize some azure instances without autoscale flag": { "initialize some azure instances without autoscale flag": {
autoscaleFlag: false, autoscaleFlag: false,
@ -622,8 +597,8 @@ func TestAutoscaleFlag(t *testing.T) {
client: &stubProtoClient{ client: &stubProtoClient{
respClient: &fakeActivationRespClient{responses: testActivationResps}, respClient: &fakeActivationRespClient{responses: testActivationResps},
}, },
waiter: stubStatusWaiter{}, waiter: stubStatusWaiter{},
pubKey: testKey, privKey: testKey,
}, },
"initialize some ec2 instances with autoscale flag": { "initialize some ec2 instances with autoscale flag": {
autoscaleFlag: true, autoscaleFlag: true,
@ -631,8 +606,8 @@ func TestAutoscaleFlag(t *testing.T) {
client: &stubProtoClient{ client: &stubProtoClient{
respClient: &fakeActivationRespClient{responses: testActivationResps}, respClient: &fakeActivationRespClient{responses: testActivationResps},
}, },
waiter: stubStatusWaiter{}, waiter: stubStatusWaiter{},
pubKey: testKey, privKey: testKey,
}, },
"initialize some gcp instances with autoscale flag": { "initialize some gcp instances with autoscale flag": {
autoscaleFlag: true, autoscaleFlag: true,
@ -640,8 +615,8 @@ func TestAutoscaleFlag(t *testing.T) {
client: &stubProtoClient{ client: &stubProtoClient{
respClient: &fakeActivationRespClient{responses: testActivationResps}, respClient: &fakeActivationRespClient{responses: testActivationResps},
}, },
waiter: stubStatusWaiter{}, waiter: stubStatusWaiter{},
pubKey: testKey, privKey: testKey,
}, },
"initialize some azure instances with autoscale flag": { "initialize some azure instances with autoscale flag": {
autoscaleFlag: true, autoscaleFlag: true,
@ -649,8 +624,8 @@ func TestAutoscaleFlag(t *testing.T) {
client: &stubProtoClient{ client: &stubProtoClient{
respClient: &fakeActivationRespClient{responses: testActivationResps}, respClient: &fakeActivationRespClient{responses: testActivationResps},
}, },
waiter: stubStatusWaiter{}, waiter: stubStatusWaiter{},
pubKey: testKey, privKey: testKey,
}, },
} }
@ -669,8 +644,8 @@ func TestAutoscaleFlag(t *testing.T) {
require.NoError(fileHandler.WriteJSON(*config.StatePath, tc.existingState, false)) require.NoError(fileHandler.WriteJSON(*config.StatePath, tc.existingState, false))
// Write key file to filesystem and set path in flag. // Write key file to filesystem and set path in flag.
require.NoError(afero.Afero{Fs: fs}.WriteFile("pubKPath", []byte(tc.pubKey), 0o600)) require.NoError(afero.Afero{Fs: fs}.WriteFile("privK", []byte(tc.privKey), 0o600))
require.NoError(cmd.Flags().Set("publickey", "pubKPath")) require.NoError(cmd.Flags().Set("privatekey", "privK"))
require.NoError(cmd.Flags().Set("autoscale", strconv.FormatBool(tc.autoscaleFlag))) require.NoError(cmd.Flags().Set("autoscale", strconv.FormatBool(tc.autoscaleFlag)))
ctx := context.Background() ctx := context.Background()