cli: unify verify/recover endpoint flag

This commit is contained in:
Thomas Tendyck 2022-05-06 13:56:02 +02:00 committed by Thomas Tendyck
parent c9226de9ab
commit 3318126363
12 changed files with 114 additions and 61 deletions

View File

@ -160,7 +160,7 @@ func initialize(ctx context.Context, cmd *cobra.Command, protCl protoClient, ser
func activate(ctx context.Context, cmd *cobra.Command, client protoClient, input activationInput,
config *config.Config, validators []atls.Validator,
) (activationResult, error) {
err := client.Connect(input.coordinatorPubIP, *config.CoordinatorPort, validators)
err := client.Connect(net.JoinHostPort(input.coordinatorPubIP, *config.CoordinatorPort), validators)
if err != nil {
return activationResult{}, err
}

View File

@ -9,7 +9,7 @@ import (
)
type protoClient interface {
Connect(ip, port string, validators []atls.Validator) error
Connect(endpoint string, validators []atls.Validator) error
Close() error
GetState(ctx context.Context) (state.State, error)
Activate(ctx context.Context, userPublicKey, masterSecret []byte, nodeIPs, coordinatorIPs, autoscalingNodeGroups []string, cloudServiceAccountURI string) (proto.ActivationResponseClient, error)

View File

@ -28,7 +28,7 @@ type stubProtoClient struct {
cloudServiceAccountURI string
}
func (c *stubProtoClient) Connect(_, _ string, _ []atls.Validator) error {
func (c *stubProtoClient) Connect(_ string, _ []atls.Validator) error {
c.conn = true
return c.connectErr
}
@ -103,9 +103,9 @@ type fakeProtoClient struct {
respClient proto.ActivationResponseClient
}
func (c *fakeProtoClient) Connect(ip, port string, validators []atls.Validator) error {
if ip == "" || port == "" {
return errors.New("ip or port is empty")
func (c *fakeProtoClient) Connect(endpoint string, validators []atls.Validator) error {
if endpoint == "" {
return errors.New("endpoint is empty")
}
if len(validators) == 0 {
return errors.New("validators is empty")

View File

@ -4,7 +4,6 @@ import (
"context"
"encoding/base64"
"errors"
"net"
"regexp"
"strings"
@ -31,8 +30,8 @@ func newRecoverCmd() *cobra.Command {
Args: cobra.ExactArgs(0),
RunE: runRecover,
}
cmd.Flags().String("ip", "", "Instance IP address.")
must(cmd.MarkFlagRequired("ip"))
cmd.Flags().StringP("endpoint", "e", "", "Endpoint of the instance. Form: HOST[:PORT]")
must(cmd.MarkFlagRequired("endpoint"))
cmd.Flags().String("disk-uuid", "", "Disk UUID of the encrypted state disk.")
must(cmd.MarkFlagRequired("disk-uuid"))
cmd.Flags().String("master-secret", "", "Path to base64 encoded master secret.")
@ -68,7 +67,7 @@ func recover(ctx context.Context, cmd *cobra.Command, fileHandler file.Handler,
}
cmd.Print(validators.WarningsIncludeInit())
if err := recoveryClient.Connect(flags.ip, *config.CoordinatorPort, validators.V()); err != nil {
if err := recoveryClient.Connect(flags.endpoint, validators.V()); err != nil {
return err
}
@ -86,12 +85,13 @@ func recover(ctx context.Context, cmd *cobra.Command, fileHandler file.Handler,
}
func parseRecoverFlags(cmd *cobra.Command, fileHandler file.Handler) (recoverFlags, error) {
ip, err := cmd.Flags().GetString("ip")
endpoint, err := cmd.Flags().GetString("endpoint")
if err != nil {
return recoverFlags{}, err
}
if netIP := net.ParseIP(ip); netIP == nil {
return recoverFlags{}, errors.New("flag '--ip' isn't a valid IP address")
endpoint, err = validateEndpoint(endpoint, constants.CoordinatorPort)
if err != nil {
return recoverFlags{}, err
}
diskUUID, err := cmd.Flags().GetString("disk-uuid")
@ -121,7 +121,7 @@ func parseRecoverFlags(cmd *cobra.Command, fileHandler file.Handler) (recoverFla
}
return recoverFlags{
ip: ip,
endpoint: endpoint,
diskUUID: diskUUID,
masterSecret: masterSecret,
devConfigPath: devConfigPath,
@ -129,7 +129,7 @@ func parseRecoverFlags(cmd *cobra.Command, fileHandler file.Handler) (recoverFla
}
type recoverFlags struct {
ip string
endpoint string
diskUUID string
masterSecret []byte
devConfigPath string

View File

@ -47,7 +47,7 @@ func TestRecover(t *testing.T) {
setupFs func(*require.Assertions) afero.Fs
existingState state.ConstellationState
client *stubRecoveryClient
ipFlag string
endpointFlag string
diskUUIDFlag string
masterSecretFlag string
devConfigFlag string
@ -63,7 +63,7 @@ func TestRecover(t *testing.T) {
},
existingState: validState,
client: &stubRecoveryClient{},
ipFlag: "192.0.2.1",
endpointFlag: "192.0.2.1",
diskUUIDFlag: "00000000-0000-0000-0000-000000000000",
wantKey: []byte{0x2e, 0x4d, 0x40, 0x3a, 0x90, 0x96, 0x6e, 0xd, 0x42, 0x3, 0x98, 0xd, 0xce, 0xc5, 0x73, 0x26, 0xf4, 0x87, 0xcf, 0x85, 0x73, 0xe1, 0xb7, 0xd6, 0xb2, 0x82, 0x4c, 0xd9, 0xbc, 0xa5, 0x7c, 0x32},
},
@ -75,7 +75,7 @@ func TestRecover(t *testing.T) {
},
existingState: validState,
client: &stubRecoveryClient{},
ipFlag: "192.0.2.1",
endpointFlag: "192.0.2.1",
diskUUIDFlag: "ABCDEFAB-CDEF-ABCD-ABCD-ABCDEFABCDEF",
wantKey: []byte{0xa9, 0x4, 0x3a, 0x74, 0x53, 0xeb, 0x23, 0xb2, 0xbc, 0x88, 0xce, 0xa7, 0x4e, 0xa9, 0xda, 0x9f, 0x11, 0x85, 0xc4, 0x2f, 0x1f, 0x25, 0x10, 0xc9, 0xec, 0xfe, 0xa, 0x6c, 0xa2, 0x6f, 0x53, 0x34},
},
@ -87,7 +87,7 @@ func TestRecover(t *testing.T) {
},
existingState: validState,
client: &stubRecoveryClient{},
ipFlag: "192.0.2.1",
endpointFlag: "192.0.2.1",
diskUUIDFlag: "abcdefab-cdef-abcd-abcd-abcdefabcdef",
wantKey: []byte{0xa9, 0x4, 0x3a, 0x74, 0x53, 0xeb, 0x23, 0xb2, 0xbc, 0x88, 0xce, 0xa7, 0x4e, 0xa9, 0xda, 0x9f, 0x11, 0x85, 0xc4, 0x2f, 0x1f, 0x25, 0x10, 0xc9, 0xec, 0xfe, 0xa, 0x6c, 0xa2, 0x6f, 0x53, 0x34},
},
@ -101,7 +101,7 @@ func TestRecover(t *testing.T) {
require.NoError(afero.WriteFile(fs, "constellation-mastersecret.base64", []byte("Y29uc3RlbGxhdGlvbi1tYXN0ZXItc2VjcmV0LWxlbmc="), 0o777))
return fs
},
ipFlag: "192.0.2.1",
endpointFlag: "192.0.2.1",
diskUUIDFlag: "00000000-0000-0000-0000-000000000000",
devConfigFlag: "nonexistent-dev-config",
wantErr: true,
@ -113,7 +113,7 @@ func TestRecover(t *testing.T) {
return fs
},
existingState: validState,
ipFlag: "192.0.2.1",
endpointFlag: "192.0.2.1",
diskUUIDFlag: "00000000-0000-0000-0000-000000000000",
stateless: true,
wantErr: true,
@ -125,7 +125,7 @@ func TestRecover(t *testing.T) {
return fs
},
existingState: invalidCSPState,
ipFlag: "192.0.2.1",
endpointFlag: "192.0.2.1",
diskUUIDFlag: "00000000-0000-0000-0000-000000000000",
wantErr: true,
},
@ -137,7 +137,7 @@ func TestRecover(t *testing.T) {
},
existingState: validState,
client: &stubRecoveryClient{connectErr: errors.New("connect failed")},
ipFlag: "192.0.2.1",
endpointFlag: "192.0.2.1",
diskUUIDFlag: "00000000-0000-0000-0000-000000000000",
wantErr: true,
},
@ -149,7 +149,7 @@ func TestRecover(t *testing.T) {
},
existingState: validState,
client: &stubRecoveryClient{pushStateDiskKeyErr: errors.New("pushing key failed")},
ipFlag: "192.0.2.1",
endpointFlag: "192.0.2.1",
diskUUIDFlag: "00000000-0000-0000-0000-000000000000",
wantErr: true,
},
@ -165,8 +165,8 @@ func TestRecover(t *testing.T) {
out := &bytes.Buffer{}
cmd.SetOut(out)
cmd.SetErr(&bytes.Buffer{})
if tc.ipFlag != "" {
require.NoError(cmd.Flags().Set("ip", tc.ipFlag))
if tc.endpointFlag != "" {
require.NoError(cmd.Flags().Set("endpoint", tc.endpointFlag))
}
if tc.diskUUIDFlag != "" {
require.NoError(cmd.Flags().Set("disk-uuid", tc.diskUUIDFlag))
@ -207,41 +207,41 @@ func TestParseRecoverFlags(t *testing.T) {
wantErr: true,
},
"invalid ip": {
args: []string{"--ip", "invalid", "--disk-uuid", "12345678-1234-1234-1234-123456789012"},
args: []string{"-e", "192.0.2.1:2:2", "--disk-uuid", "12345678-1234-1234-1234-123456789012"},
wantErr: true,
},
"invalid disk uuid": {
args: []string{"--ip", "192.0.2.1", "--disk-uuid", "invalid"},
args: []string{"-e", "192.0.2.1:2", "--disk-uuid", "invalid"},
wantErr: true,
},
"invalid master secret path": {
args: []string{"--ip", "192.0.2.1", "--disk-uuid", "12345678-1234-1234-1234-123456789012", "--master-secret", "invalid"},
args: []string{"-e", "192.0.2.1:2", "--disk-uuid", "12345678-1234-1234-1234-123456789012", "--master-secret", "invalid"},
wantErr: true,
},
"minimal args set": {
args: []string{"--ip", "192.0.2.1", "--disk-uuid", "12345678-1234-1234-1234-123456789012"},
args: []string{"-e", "192.0.2.1:2", "--disk-uuid", "12345678-1234-1234-1234-123456789012"},
wantFlags: recoverFlags{
ip: "192.0.2.1",
endpoint: "192.0.2.1:2",
diskUUID: "12345678-1234-1234-1234-123456789012",
masterSecret: []byte("constellation-master-secret-leng"),
},
},
"all args set": {
args: []string{
"--ip", "192.0.2.1", "--disk-uuid", "12345678-1234-1234-1234-123456789012",
"-e", "192.0.2.1:2", "--disk-uuid", "12345678-1234-1234-1234-123456789012",
"--master-secret", "constellation-mastersecret.base64", "--dev-config", "dev-config-path",
},
wantFlags: recoverFlags{
ip: "192.0.2.1",
endpoint: "192.0.2.1:2",
diskUUID: "12345678-1234-1234-1234-123456789012",
masterSecret: []byte("constellation-master-secret-leng"),
devConfigPath: "dev-config-path",
},
},
"uppercase disk-uuid is converted to lowercase": {
args: []string{"--ip", "192.0.2.1", "--disk-uuid", "ABCDEFAB-CDEF-ABCD-ABCD-ABCDEFABCDEF"},
args: []string{"-e", "192.0.2.1:2", "--disk-uuid", "ABCDEFAB-CDEF-ABCD-ABCD-ABCDEFABCDEF"},
wantFlags: recoverFlags{
ip: "192.0.2.1",
endpoint: "192.0.2.1:2",
diskUUID: "abcdefab-cdef-abcd-abcd-abcdefabcdef",
masterSecret: []byte("constellation-master-secret-leng"),
},

View File

@ -8,7 +8,7 @@ import (
)
type recoveryClient interface {
Connect(ip, port string, validators []atls.Validator) error
Connect(endpoint string, validators []atls.Validator) error
PushStateDiskKey(ctx context.Context, stateDiskKey []byte) error
io.Closer
}

View File

@ -15,7 +15,7 @@ type stubRecoveryClient struct {
pushStateDiskKeyKey []byte
}
func (c *stubRecoveryClient) Connect(_, _ string, _ []atls.Validator) error {
func (c *stubRecoveryClient) Connect(_ string, _ []atls.Validator) error {
c.conn = true
return c.connectErr
}

View File

@ -3,6 +3,9 @@ package cmd
import (
"errors"
"fmt"
"net"
"strconv"
"strings"
"github.com/edgelesssys/constellation/cli/azure"
"github.com/edgelesssys/constellation/cli/cloudprovider"
@ -49,3 +52,16 @@ func validInstanceTypeForProvider(insType string, provider cloudprovider.Provide
return fmt.Errorf("%s isn't a valid cloud platform", provider)
}
}
func validateEndpoint(endpoint string, defaultPort int) (string, error) {
_, _, err := net.SplitHostPort(endpoint)
if err == nil {
return endpoint, nil
}
if strings.Contains(err.Error(), "missing port in address") {
return net.JoinHostPort(endpoint, strconv.Itoa(defaultPort)), nil
}
return "", err
}

View File

@ -5,6 +5,7 @@ import (
"github.com/spf13/cobra"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestIsCloudProvider(t *testing.T) {
@ -36,3 +37,54 @@ func TestIsCloudProvider(t *testing.T) {
})
}
}
func TestValidateEndpoint(t *testing.T) {
testCases := map[string]struct {
endpoint string
defaultPort int
wantResult string
wantErr bool
}{
"ip and port": {
endpoint: "192.0.2.1:2",
defaultPort: 3,
wantResult: "192.0.2.1:2",
},
"hostname and port": {
endpoint: "foo:2",
defaultPort: 3,
wantResult: "foo:2",
},
"ip": {
endpoint: "192.0.2.1",
defaultPort: 3,
wantResult: "192.0.2.1:3",
},
"hostname": {
endpoint: "foo",
defaultPort: 3,
wantResult: "foo:3",
},
"invalid endpoint": {
endpoint: "foo:2:2",
defaultPort: 3,
wantErr: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
res, err := validateEndpoint(tc.endpoint, tc.defaultPort)
if tc.wantErr {
assert.Error(err)
return
}
require.NoError(err)
assert.Equal(tc.wantResult, res)
})
}
}

View File

@ -4,9 +4,6 @@ import (
"context"
"errors"
"fmt"
"net"
"strconv"
"strings"
"github.com/edgelesssys/constellation/cli/cloud/cloudcmd"
"github.com/edgelesssys/constellation/cli/cloudprovider"
@ -69,7 +66,7 @@ func verify(ctx context.Context, cmd *cobra.Command, provider cloudprovider.Prov
cmd.Print(validators.Warnings())
}
if err := protoClient.Connect(flags.nodeHost, flags.nodePort, validators.V()); err != nil {
if err := protoClient.Connect(flags.endpoint, validators.V()); err != nil {
return err
}
if _, err := protoClient.GetState(ctx); err != nil {
@ -100,13 +97,9 @@ func parseVerifyFlags(cmd *cobra.Command) (verifyFlags, error) {
if err != nil {
return verifyFlags{}, err
}
host, port, err := net.SplitHostPort(endpoint)
endpoint, err = validateEndpoint(endpoint, constants.CoordinatorPort)
if err != nil {
if !strings.Contains(err.Error(), "missing port in address") {
return verifyFlags{}, err
}
host = endpoint
port = strconv.Itoa(constants.CoordinatorPort)
return verifyFlags{}, err
}
devConfigPath, err := cmd.Flags().GetString("dev-config")
@ -115,8 +108,7 @@ func parseVerifyFlags(cmd *cobra.Command) (verifyFlags, error) {
}
return verifyFlags{
nodeHost: host,
nodePort: port,
endpoint: endpoint,
devConfigPath: devConfigPath,
ownerID: ownerID,
clusterID: clusterID,
@ -124,8 +116,7 @@ func parseVerifyFlags(cmd *cobra.Command) (verifyFlags, error) {
}
type verifyFlags struct {
nodeHost string
nodePort string
endpoint string
ownerID string
clusterID string
devConfigPath string

View File

@ -4,7 +4,6 @@ import (
"context"
"errors"
"io"
"net"
"github.com/edgelesssys/constellation/coordinator/atls"
"github.com/edgelesssys/constellation/coordinator/kms"
@ -26,15 +25,13 @@ type Client struct {
// The connection must be closed using Close(). If connect is
// called on a client that already has a connection, the old
// connection is closed.
func (c *Client) Connect(ip, port string, validators []atls.Validator) error {
addr := net.JoinHostPort(ip, port)
func (c *Client) Connect(endpoint string, validators []atls.Validator) error {
tlsConfig, err := atls.CreateAttestationClientTLSConfig(validators)
if err != nil {
return err
}
conn, err := grpc.Dial(addr, grpc.WithTransportCredentials(credentials.NewTLS(tlsConfig)))
conn, err := grpc.Dial(endpoint, grpc.WithTransportCredentials(credentials.NewTLS(tlsConfig)))
if err != nil {
return err
}

View File

@ -3,7 +3,6 @@ package proto
import (
"context"
"errors"
"net"
"github.com/edgelesssys/constellation/coordinator/atls"
"github.com/edgelesssys/constellation/state/keyservice/keyproto"
@ -22,15 +21,13 @@ type KeyClient struct {
// The connection must be closed using Close(). If connect is
// called on a client that already has a connection, the old
// connection is closed.
func (c *KeyClient) Connect(ip, port string, validators []atls.Validator) error {
addr := net.JoinHostPort(ip, port)
func (c *KeyClient) Connect(endpoint string, validators []atls.Validator) error {
tlsConfig, err := atls.CreateAttestationClientTLSConfig(validators)
if err != nil {
return err
}
conn, err := grpc.Dial(addr, grpc.WithTransportCredentials(credentials.NewTLS(tlsConfig)))
conn, err := grpc.Dial(endpoint, grpc.WithTransportCredentials(credentials.NewTLS(tlsConfig)))
if err != nil {
return err
}