diff --git a/cli/cmd/init.go b/cli/cmd/init.go index 2f5ae10e9..87fba4af3 100644 --- a/cli/cmd/init.go +++ b/cli/cmd/init.go @@ -160,7 +160,7 @@ func initialize(ctx context.Context, cmd *cobra.Command, protCl protoClient, ser func activate(ctx context.Context, cmd *cobra.Command, client protoClient, input activationInput, config *config.Config, validators []atls.Validator, ) (activationResult, error) { - err := client.Connect(input.coordinatorPubIP, *config.CoordinatorPort, validators) + err := client.Connect(net.JoinHostPort(input.coordinatorPubIP, *config.CoordinatorPort), validators) if err != nil { return activationResult{}, err } diff --git a/cli/cmd/protoclient.go b/cli/cmd/protoclient.go index d03266715..85322a399 100644 --- a/cli/cmd/protoclient.go +++ b/cli/cmd/protoclient.go @@ -9,7 +9,7 @@ import ( ) type protoClient interface { - Connect(ip, port string, validators []atls.Validator) error + Connect(endpoint string, validators []atls.Validator) error Close() error GetState(ctx context.Context) (state.State, error) Activate(ctx context.Context, userPublicKey, masterSecret []byte, nodeIPs, coordinatorIPs, autoscalingNodeGroups []string, cloudServiceAccountURI string) (proto.ActivationResponseClient, error) diff --git a/cli/cmd/protoclient_test.go b/cli/cmd/protoclient_test.go index 1b63a5861..7117b63bc 100644 --- a/cli/cmd/protoclient_test.go +++ b/cli/cmd/protoclient_test.go @@ -28,7 +28,7 @@ type stubProtoClient struct { cloudServiceAccountURI string } -func (c *stubProtoClient) Connect(_, _ string, _ []atls.Validator) error { +func (c *stubProtoClient) Connect(_ string, _ []atls.Validator) error { c.conn = true return c.connectErr } @@ -103,9 +103,9 @@ type fakeProtoClient struct { respClient proto.ActivationResponseClient } -func (c *fakeProtoClient) Connect(ip, port string, validators []atls.Validator) error { - if ip == "" || port == "" { - return errors.New("ip or port is empty") +func (c *fakeProtoClient) Connect(endpoint string, validators []atls.Validator) error { + if endpoint == "" { + return errors.New("endpoint is empty") } if len(validators) == 0 { return errors.New("validators is empty") diff --git a/cli/cmd/recover.go b/cli/cmd/recover.go index c26534e03..e515d18a9 100644 --- a/cli/cmd/recover.go +++ b/cli/cmd/recover.go @@ -4,7 +4,6 @@ import ( "context" "encoding/base64" "errors" - "net" "regexp" "strings" @@ -31,8 +30,8 @@ func newRecoverCmd() *cobra.Command { Args: cobra.ExactArgs(0), RunE: runRecover, } - cmd.Flags().String("ip", "", "Instance IP address.") - must(cmd.MarkFlagRequired("ip")) + cmd.Flags().StringP("endpoint", "e", "", "Endpoint of the instance. Form: HOST[:PORT]") + must(cmd.MarkFlagRequired("endpoint")) cmd.Flags().String("disk-uuid", "", "Disk UUID of the encrypted state disk.") must(cmd.MarkFlagRequired("disk-uuid")) cmd.Flags().String("master-secret", "", "Path to base64 encoded master secret.") @@ -68,7 +67,7 @@ func recover(ctx context.Context, cmd *cobra.Command, fileHandler file.Handler, } cmd.Print(validators.WarningsIncludeInit()) - if err := recoveryClient.Connect(flags.ip, *config.CoordinatorPort, validators.V()); err != nil { + if err := recoveryClient.Connect(flags.endpoint, validators.V()); err != nil { return err } @@ -86,12 +85,13 @@ func recover(ctx context.Context, cmd *cobra.Command, fileHandler file.Handler, } func parseRecoverFlags(cmd *cobra.Command, fileHandler file.Handler) (recoverFlags, error) { - ip, err := cmd.Flags().GetString("ip") + endpoint, err := cmd.Flags().GetString("endpoint") if err != nil { return recoverFlags{}, err } - if netIP := net.ParseIP(ip); netIP == nil { - return recoverFlags{}, errors.New("flag '--ip' isn't a valid IP address") + endpoint, err = validateEndpoint(endpoint, constants.CoordinatorPort) + if err != nil { + return recoverFlags{}, err } diskUUID, err := cmd.Flags().GetString("disk-uuid") @@ -121,7 +121,7 @@ func parseRecoverFlags(cmd *cobra.Command, fileHandler file.Handler) (recoverFla } return recoverFlags{ - ip: ip, + endpoint: endpoint, diskUUID: diskUUID, masterSecret: masterSecret, devConfigPath: devConfigPath, @@ -129,7 +129,7 @@ func parseRecoverFlags(cmd *cobra.Command, fileHandler file.Handler) (recoverFla } type recoverFlags struct { - ip string + endpoint string diskUUID string masterSecret []byte devConfigPath string diff --git a/cli/cmd/recover_test.go b/cli/cmd/recover_test.go index e4ea1d3c5..5a924d5bc 100644 --- a/cli/cmd/recover_test.go +++ b/cli/cmd/recover_test.go @@ -47,7 +47,7 @@ func TestRecover(t *testing.T) { setupFs func(*require.Assertions) afero.Fs existingState state.ConstellationState client *stubRecoveryClient - ipFlag string + endpointFlag string diskUUIDFlag string masterSecretFlag string devConfigFlag string @@ -63,7 +63,7 @@ func TestRecover(t *testing.T) { }, existingState: validState, client: &stubRecoveryClient{}, - ipFlag: "192.0.2.1", + endpointFlag: "192.0.2.1", diskUUIDFlag: "00000000-0000-0000-0000-000000000000", wantKey: []byte{0x2e, 0x4d, 0x40, 0x3a, 0x90, 0x96, 0x6e, 0xd, 0x42, 0x3, 0x98, 0xd, 0xce, 0xc5, 0x73, 0x26, 0xf4, 0x87, 0xcf, 0x85, 0x73, 0xe1, 0xb7, 0xd6, 0xb2, 0x82, 0x4c, 0xd9, 0xbc, 0xa5, 0x7c, 0x32}, }, @@ -75,7 +75,7 @@ func TestRecover(t *testing.T) { }, existingState: validState, client: &stubRecoveryClient{}, - ipFlag: "192.0.2.1", + endpointFlag: "192.0.2.1", diskUUIDFlag: "ABCDEFAB-CDEF-ABCD-ABCD-ABCDEFABCDEF", wantKey: []byte{0xa9, 0x4, 0x3a, 0x74, 0x53, 0xeb, 0x23, 0xb2, 0xbc, 0x88, 0xce, 0xa7, 0x4e, 0xa9, 0xda, 0x9f, 0x11, 0x85, 0xc4, 0x2f, 0x1f, 0x25, 0x10, 0xc9, 0xec, 0xfe, 0xa, 0x6c, 0xa2, 0x6f, 0x53, 0x34}, }, @@ -87,7 +87,7 @@ func TestRecover(t *testing.T) { }, existingState: validState, client: &stubRecoveryClient{}, - ipFlag: "192.0.2.1", + endpointFlag: "192.0.2.1", diskUUIDFlag: "abcdefab-cdef-abcd-abcd-abcdefabcdef", wantKey: []byte{0xa9, 0x4, 0x3a, 0x74, 0x53, 0xeb, 0x23, 0xb2, 0xbc, 0x88, 0xce, 0xa7, 0x4e, 0xa9, 0xda, 0x9f, 0x11, 0x85, 0xc4, 0x2f, 0x1f, 0x25, 0x10, 0xc9, 0xec, 0xfe, 0xa, 0x6c, 0xa2, 0x6f, 0x53, 0x34}, }, @@ -101,7 +101,7 @@ func TestRecover(t *testing.T) { require.NoError(afero.WriteFile(fs, "constellation-mastersecret.base64", []byte("Y29uc3RlbGxhdGlvbi1tYXN0ZXItc2VjcmV0LWxlbmc="), 0o777)) return fs }, - ipFlag: "192.0.2.1", + endpointFlag: "192.0.2.1", diskUUIDFlag: "00000000-0000-0000-0000-000000000000", devConfigFlag: "nonexistent-dev-config", wantErr: true, @@ -113,7 +113,7 @@ func TestRecover(t *testing.T) { return fs }, existingState: validState, - ipFlag: "192.0.2.1", + endpointFlag: "192.0.2.1", diskUUIDFlag: "00000000-0000-0000-0000-000000000000", stateless: true, wantErr: true, @@ -125,7 +125,7 @@ func TestRecover(t *testing.T) { return fs }, existingState: invalidCSPState, - ipFlag: "192.0.2.1", + endpointFlag: "192.0.2.1", diskUUIDFlag: "00000000-0000-0000-0000-000000000000", wantErr: true, }, @@ -137,7 +137,7 @@ func TestRecover(t *testing.T) { }, existingState: validState, client: &stubRecoveryClient{connectErr: errors.New("connect failed")}, - ipFlag: "192.0.2.1", + endpointFlag: "192.0.2.1", diskUUIDFlag: "00000000-0000-0000-0000-000000000000", wantErr: true, }, @@ -149,7 +149,7 @@ func TestRecover(t *testing.T) { }, existingState: validState, client: &stubRecoveryClient{pushStateDiskKeyErr: errors.New("pushing key failed")}, - ipFlag: "192.0.2.1", + endpointFlag: "192.0.2.1", diskUUIDFlag: "00000000-0000-0000-0000-000000000000", wantErr: true, }, @@ -165,8 +165,8 @@ func TestRecover(t *testing.T) { out := &bytes.Buffer{} cmd.SetOut(out) cmd.SetErr(&bytes.Buffer{}) - if tc.ipFlag != "" { - require.NoError(cmd.Flags().Set("ip", tc.ipFlag)) + if tc.endpointFlag != "" { + require.NoError(cmd.Flags().Set("endpoint", tc.endpointFlag)) } if tc.diskUUIDFlag != "" { require.NoError(cmd.Flags().Set("disk-uuid", tc.diskUUIDFlag)) @@ -207,41 +207,41 @@ func TestParseRecoverFlags(t *testing.T) { wantErr: true, }, "invalid ip": { - args: []string{"--ip", "invalid", "--disk-uuid", "12345678-1234-1234-1234-123456789012"}, + args: []string{"-e", "192.0.2.1:2:2", "--disk-uuid", "12345678-1234-1234-1234-123456789012"}, wantErr: true, }, "invalid disk uuid": { - args: []string{"--ip", "192.0.2.1", "--disk-uuid", "invalid"}, + args: []string{"-e", "192.0.2.1:2", "--disk-uuid", "invalid"}, wantErr: true, }, "invalid master secret path": { - args: []string{"--ip", "192.0.2.1", "--disk-uuid", "12345678-1234-1234-1234-123456789012", "--master-secret", "invalid"}, + args: []string{"-e", "192.0.2.1:2", "--disk-uuid", "12345678-1234-1234-1234-123456789012", "--master-secret", "invalid"}, wantErr: true, }, "minimal args set": { - args: []string{"--ip", "192.0.2.1", "--disk-uuid", "12345678-1234-1234-1234-123456789012"}, + args: []string{"-e", "192.0.2.1:2", "--disk-uuid", "12345678-1234-1234-1234-123456789012"}, wantFlags: recoverFlags{ - ip: "192.0.2.1", + endpoint: "192.0.2.1:2", diskUUID: "12345678-1234-1234-1234-123456789012", masterSecret: []byte("constellation-master-secret-leng"), }, }, "all args set": { args: []string{ - "--ip", "192.0.2.1", "--disk-uuid", "12345678-1234-1234-1234-123456789012", + "-e", "192.0.2.1:2", "--disk-uuid", "12345678-1234-1234-1234-123456789012", "--master-secret", "constellation-mastersecret.base64", "--dev-config", "dev-config-path", }, wantFlags: recoverFlags{ - ip: "192.0.2.1", + endpoint: "192.0.2.1:2", diskUUID: "12345678-1234-1234-1234-123456789012", masterSecret: []byte("constellation-master-secret-leng"), devConfigPath: "dev-config-path", }, }, "uppercase disk-uuid is converted to lowercase": { - args: []string{"--ip", "192.0.2.1", "--disk-uuid", "ABCDEFAB-CDEF-ABCD-ABCD-ABCDEFABCDEF"}, + args: []string{"-e", "192.0.2.1:2", "--disk-uuid", "ABCDEFAB-CDEF-ABCD-ABCD-ABCDEFABCDEF"}, wantFlags: recoverFlags{ - ip: "192.0.2.1", + endpoint: "192.0.2.1:2", diskUUID: "abcdefab-cdef-abcd-abcd-abcdefabcdef", masterSecret: []byte("constellation-master-secret-leng"), }, diff --git a/cli/cmd/recoveryclient.go b/cli/cmd/recoveryclient.go index 6e9dad1dc..a27eee4b0 100644 --- a/cli/cmd/recoveryclient.go +++ b/cli/cmd/recoveryclient.go @@ -8,7 +8,7 @@ import ( ) type recoveryClient interface { - Connect(ip, port string, validators []atls.Validator) error + Connect(endpoint string, validators []atls.Validator) error PushStateDiskKey(ctx context.Context, stateDiskKey []byte) error io.Closer } diff --git a/cli/cmd/recoveryclient_test.go b/cli/cmd/recoveryclient_test.go index e65fcf333..3e4b49e87 100644 --- a/cli/cmd/recoveryclient_test.go +++ b/cli/cmd/recoveryclient_test.go @@ -15,7 +15,7 @@ type stubRecoveryClient struct { pushStateDiskKeyKey []byte } -func (c *stubRecoveryClient) Connect(_, _ string, _ []atls.Validator) error { +func (c *stubRecoveryClient) Connect(_ string, _ []atls.Validator) error { c.conn = true return c.connectErr } diff --git a/cli/cmd/validargs.go b/cli/cmd/validargs.go index 0134ac53f..9ae6c0370 100644 --- a/cli/cmd/validargs.go +++ b/cli/cmd/validargs.go @@ -3,6 +3,9 @@ package cmd import ( "errors" "fmt" + "net" + "strconv" + "strings" "github.com/edgelesssys/constellation/cli/azure" "github.com/edgelesssys/constellation/cli/cloudprovider" @@ -49,3 +52,16 @@ func validInstanceTypeForProvider(insType string, provider cloudprovider.Provide return fmt.Errorf("%s isn't a valid cloud platform", provider) } } + +func validateEndpoint(endpoint string, defaultPort int) (string, error) { + _, _, err := net.SplitHostPort(endpoint) + if err == nil { + return endpoint, nil + } + + if strings.Contains(err.Error(), "missing port in address") { + return net.JoinHostPort(endpoint, strconv.Itoa(defaultPort)), nil + } + + return "", err +} diff --git a/cli/cmd/validargs_test.go b/cli/cmd/validargs_test.go index 508bcfa4c..f9ebd9b5c 100644 --- a/cli/cmd/validargs_test.go +++ b/cli/cmd/validargs_test.go @@ -5,6 +5,7 @@ import ( "github.com/spf13/cobra" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestIsCloudProvider(t *testing.T) { @@ -36,3 +37,54 @@ func TestIsCloudProvider(t *testing.T) { }) } } + +func TestValidateEndpoint(t *testing.T) { + testCases := map[string]struct { + endpoint string + defaultPort int + wantResult string + wantErr bool + }{ + "ip and port": { + endpoint: "192.0.2.1:2", + defaultPort: 3, + wantResult: "192.0.2.1:2", + }, + "hostname and port": { + endpoint: "foo:2", + defaultPort: 3, + wantResult: "foo:2", + }, + "ip": { + endpoint: "192.0.2.1", + defaultPort: 3, + wantResult: "192.0.2.1:3", + }, + "hostname": { + endpoint: "foo", + defaultPort: 3, + wantResult: "foo:3", + }, + "invalid endpoint": { + endpoint: "foo:2:2", + defaultPort: 3, + wantErr: true, + }, + } + + for name, tc := range testCases { + t.Run(name, func(t *testing.T) { + assert := assert.New(t) + require := require.New(t) + + res, err := validateEndpoint(tc.endpoint, tc.defaultPort) + if tc.wantErr { + assert.Error(err) + return + } + + require.NoError(err) + assert.Equal(tc.wantResult, res) + }) + } +} diff --git a/cli/cmd/verify.go b/cli/cmd/verify.go index eb4887b35..07decb4cb 100644 --- a/cli/cmd/verify.go +++ b/cli/cmd/verify.go @@ -4,9 +4,6 @@ import ( "context" "errors" "fmt" - "net" - "strconv" - "strings" "github.com/edgelesssys/constellation/cli/cloud/cloudcmd" "github.com/edgelesssys/constellation/cli/cloudprovider" @@ -69,7 +66,7 @@ func verify(ctx context.Context, cmd *cobra.Command, provider cloudprovider.Prov cmd.Print(validators.Warnings()) } - if err := protoClient.Connect(flags.nodeHost, flags.nodePort, validators.V()); err != nil { + if err := protoClient.Connect(flags.endpoint, validators.V()); err != nil { return err } if _, err := protoClient.GetState(ctx); err != nil { @@ -100,13 +97,9 @@ func parseVerifyFlags(cmd *cobra.Command) (verifyFlags, error) { if err != nil { return verifyFlags{}, err } - host, port, err := net.SplitHostPort(endpoint) + endpoint, err = validateEndpoint(endpoint, constants.CoordinatorPort) if err != nil { - if !strings.Contains(err.Error(), "missing port in address") { - return verifyFlags{}, err - } - host = endpoint - port = strconv.Itoa(constants.CoordinatorPort) + return verifyFlags{}, err } devConfigPath, err := cmd.Flags().GetString("dev-config") @@ -115,8 +108,7 @@ func parseVerifyFlags(cmd *cobra.Command) (verifyFlags, error) { } return verifyFlags{ - nodeHost: host, - nodePort: port, + endpoint: endpoint, devConfigPath: devConfigPath, ownerID: ownerID, clusterID: clusterID, @@ -124,8 +116,7 @@ func parseVerifyFlags(cmd *cobra.Command) (verifyFlags, error) { } type verifyFlags struct { - nodeHost string - nodePort string + endpoint string ownerID string clusterID string devConfigPath string diff --git a/cli/proto/client.go b/cli/proto/client.go index 264028e3d..b7f9c748a 100644 --- a/cli/proto/client.go +++ b/cli/proto/client.go @@ -4,7 +4,6 @@ import ( "context" "errors" "io" - "net" "github.com/edgelesssys/constellation/coordinator/atls" "github.com/edgelesssys/constellation/coordinator/kms" @@ -26,15 +25,13 @@ type Client struct { // The connection must be closed using Close(). If connect is // called on a client that already has a connection, the old // connection is closed. -func (c *Client) Connect(ip, port string, validators []atls.Validator) error { - addr := net.JoinHostPort(ip, port) - +func (c *Client) Connect(endpoint string, validators []atls.Validator) error { tlsConfig, err := atls.CreateAttestationClientTLSConfig(validators) if err != nil { return err } - conn, err := grpc.Dial(addr, grpc.WithTransportCredentials(credentials.NewTLS(tlsConfig))) + conn, err := grpc.Dial(endpoint, grpc.WithTransportCredentials(credentials.NewTLS(tlsConfig))) if err != nil { return err } diff --git a/cli/proto/recover.go b/cli/proto/recover.go index 366f2855b..dde4cdacf 100644 --- a/cli/proto/recover.go +++ b/cli/proto/recover.go @@ -3,7 +3,6 @@ package proto import ( "context" "errors" - "net" "github.com/edgelesssys/constellation/coordinator/atls" "github.com/edgelesssys/constellation/state/keyservice/keyproto" @@ -22,15 +21,13 @@ type KeyClient struct { // The connection must be closed using Close(). If connect is // called on a client that already has a connection, the old // connection is closed. -func (c *KeyClient) Connect(ip, port string, validators []atls.Validator) error { - addr := net.JoinHostPort(ip, port) - +func (c *KeyClient) Connect(endpoint string, validators []atls.Validator) error { tlsConfig, err := atls.CreateAttestationClientTLSConfig(validators) if err != nil { return err } - conn, err := grpc.Dial(addr, grpc.WithTransportCredentials(credentials.NewTLS(tlsConfig))) + conn, err := grpc.Dial(endpoint, grpc.WithTransportCredentials(credentials.NewTLS(tlsConfig))) if err != nil { return err }