mirror of
https://github.com/edgelesssys/constellation.git
synced 2024-10-01 01:36:09 -04:00
cli: unify verify/recover endpoint flag
This commit is contained in:
parent
c9226de9ab
commit
3318126363
@ -160,7 +160,7 @@ func initialize(ctx context.Context, cmd *cobra.Command, protCl protoClient, ser
|
||||
func activate(ctx context.Context, cmd *cobra.Command, client protoClient, input activationInput,
|
||||
config *config.Config, validators []atls.Validator,
|
||||
) (activationResult, error) {
|
||||
err := client.Connect(input.coordinatorPubIP, *config.CoordinatorPort, validators)
|
||||
err := client.Connect(net.JoinHostPort(input.coordinatorPubIP, *config.CoordinatorPort), validators)
|
||||
if err != nil {
|
||||
return activationResult{}, err
|
||||
}
|
||||
|
@ -9,7 +9,7 @@ import (
|
||||
)
|
||||
|
||||
type protoClient interface {
|
||||
Connect(ip, port string, validators []atls.Validator) error
|
||||
Connect(endpoint string, validators []atls.Validator) error
|
||||
Close() error
|
||||
GetState(ctx context.Context) (state.State, error)
|
||||
Activate(ctx context.Context, userPublicKey, masterSecret []byte, nodeIPs, coordinatorIPs, autoscalingNodeGroups []string, cloudServiceAccountURI string) (proto.ActivationResponseClient, error)
|
||||
|
@ -28,7 +28,7 @@ type stubProtoClient struct {
|
||||
cloudServiceAccountURI string
|
||||
}
|
||||
|
||||
func (c *stubProtoClient) Connect(_, _ string, _ []atls.Validator) error {
|
||||
func (c *stubProtoClient) Connect(_ string, _ []atls.Validator) error {
|
||||
c.conn = true
|
||||
return c.connectErr
|
||||
}
|
||||
@ -103,9 +103,9 @@ type fakeProtoClient struct {
|
||||
respClient proto.ActivationResponseClient
|
||||
}
|
||||
|
||||
func (c *fakeProtoClient) Connect(ip, port string, validators []atls.Validator) error {
|
||||
if ip == "" || port == "" {
|
||||
return errors.New("ip or port is empty")
|
||||
func (c *fakeProtoClient) Connect(endpoint string, validators []atls.Validator) error {
|
||||
if endpoint == "" {
|
||||
return errors.New("endpoint is empty")
|
||||
}
|
||||
if len(validators) == 0 {
|
||||
return errors.New("validators is empty")
|
||||
|
@ -4,7 +4,6 @@ import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"net"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
@ -31,8 +30,8 @@ func newRecoverCmd() *cobra.Command {
|
||||
Args: cobra.ExactArgs(0),
|
||||
RunE: runRecover,
|
||||
}
|
||||
cmd.Flags().String("ip", "", "Instance IP address.")
|
||||
must(cmd.MarkFlagRequired("ip"))
|
||||
cmd.Flags().StringP("endpoint", "e", "", "Endpoint of the instance. Form: HOST[:PORT]")
|
||||
must(cmd.MarkFlagRequired("endpoint"))
|
||||
cmd.Flags().String("disk-uuid", "", "Disk UUID of the encrypted state disk.")
|
||||
must(cmd.MarkFlagRequired("disk-uuid"))
|
||||
cmd.Flags().String("master-secret", "", "Path to base64 encoded master secret.")
|
||||
@ -68,7 +67,7 @@ func recover(ctx context.Context, cmd *cobra.Command, fileHandler file.Handler,
|
||||
}
|
||||
cmd.Print(validators.WarningsIncludeInit())
|
||||
|
||||
if err := recoveryClient.Connect(flags.ip, *config.CoordinatorPort, validators.V()); err != nil {
|
||||
if err := recoveryClient.Connect(flags.endpoint, validators.V()); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@ -86,12 +85,13 @@ func recover(ctx context.Context, cmd *cobra.Command, fileHandler file.Handler,
|
||||
}
|
||||
|
||||
func parseRecoverFlags(cmd *cobra.Command, fileHandler file.Handler) (recoverFlags, error) {
|
||||
ip, err := cmd.Flags().GetString("ip")
|
||||
endpoint, err := cmd.Flags().GetString("endpoint")
|
||||
if err != nil {
|
||||
return recoverFlags{}, err
|
||||
}
|
||||
if netIP := net.ParseIP(ip); netIP == nil {
|
||||
return recoverFlags{}, errors.New("flag '--ip' isn't a valid IP address")
|
||||
endpoint, err = validateEndpoint(endpoint, constants.CoordinatorPort)
|
||||
if err != nil {
|
||||
return recoverFlags{}, err
|
||||
}
|
||||
|
||||
diskUUID, err := cmd.Flags().GetString("disk-uuid")
|
||||
@ -121,7 +121,7 @@ func parseRecoverFlags(cmd *cobra.Command, fileHandler file.Handler) (recoverFla
|
||||
}
|
||||
|
||||
return recoverFlags{
|
||||
ip: ip,
|
||||
endpoint: endpoint,
|
||||
diskUUID: diskUUID,
|
||||
masterSecret: masterSecret,
|
||||
devConfigPath: devConfigPath,
|
||||
@ -129,7 +129,7 @@ func parseRecoverFlags(cmd *cobra.Command, fileHandler file.Handler) (recoverFla
|
||||
}
|
||||
|
||||
type recoverFlags struct {
|
||||
ip string
|
||||
endpoint string
|
||||
diskUUID string
|
||||
masterSecret []byte
|
||||
devConfigPath string
|
||||
|
@ -47,7 +47,7 @@ func TestRecover(t *testing.T) {
|
||||
setupFs func(*require.Assertions) afero.Fs
|
||||
existingState state.ConstellationState
|
||||
client *stubRecoveryClient
|
||||
ipFlag string
|
||||
endpointFlag string
|
||||
diskUUIDFlag string
|
||||
masterSecretFlag string
|
||||
devConfigFlag string
|
||||
@ -63,7 +63,7 @@ func TestRecover(t *testing.T) {
|
||||
},
|
||||
existingState: validState,
|
||||
client: &stubRecoveryClient{},
|
||||
ipFlag: "192.0.2.1",
|
||||
endpointFlag: "192.0.2.1",
|
||||
diskUUIDFlag: "00000000-0000-0000-0000-000000000000",
|
||||
wantKey: []byte{0x2e, 0x4d, 0x40, 0x3a, 0x90, 0x96, 0x6e, 0xd, 0x42, 0x3, 0x98, 0xd, 0xce, 0xc5, 0x73, 0x26, 0xf4, 0x87, 0xcf, 0x85, 0x73, 0xe1, 0xb7, 0xd6, 0xb2, 0x82, 0x4c, 0xd9, 0xbc, 0xa5, 0x7c, 0x32},
|
||||
},
|
||||
@ -75,7 +75,7 @@ func TestRecover(t *testing.T) {
|
||||
},
|
||||
existingState: validState,
|
||||
client: &stubRecoveryClient{},
|
||||
ipFlag: "192.0.2.1",
|
||||
endpointFlag: "192.0.2.1",
|
||||
diskUUIDFlag: "ABCDEFAB-CDEF-ABCD-ABCD-ABCDEFABCDEF",
|
||||
wantKey: []byte{0xa9, 0x4, 0x3a, 0x74, 0x53, 0xeb, 0x23, 0xb2, 0xbc, 0x88, 0xce, 0xa7, 0x4e, 0xa9, 0xda, 0x9f, 0x11, 0x85, 0xc4, 0x2f, 0x1f, 0x25, 0x10, 0xc9, 0xec, 0xfe, 0xa, 0x6c, 0xa2, 0x6f, 0x53, 0x34},
|
||||
},
|
||||
@ -87,7 +87,7 @@ func TestRecover(t *testing.T) {
|
||||
},
|
||||
existingState: validState,
|
||||
client: &stubRecoveryClient{},
|
||||
ipFlag: "192.0.2.1",
|
||||
endpointFlag: "192.0.2.1",
|
||||
diskUUIDFlag: "abcdefab-cdef-abcd-abcd-abcdefabcdef",
|
||||
wantKey: []byte{0xa9, 0x4, 0x3a, 0x74, 0x53, 0xeb, 0x23, 0xb2, 0xbc, 0x88, 0xce, 0xa7, 0x4e, 0xa9, 0xda, 0x9f, 0x11, 0x85, 0xc4, 0x2f, 0x1f, 0x25, 0x10, 0xc9, 0xec, 0xfe, 0xa, 0x6c, 0xa2, 0x6f, 0x53, 0x34},
|
||||
},
|
||||
@ -101,7 +101,7 @@ func TestRecover(t *testing.T) {
|
||||
require.NoError(afero.WriteFile(fs, "constellation-mastersecret.base64", []byte("Y29uc3RlbGxhdGlvbi1tYXN0ZXItc2VjcmV0LWxlbmc="), 0o777))
|
||||
return fs
|
||||
},
|
||||
ipFlag: "192.0.2.1",
|
||||
endpointFlag: "192.0.2.1",
|
||||
diskUUIDFlag: "00000000-0000-0000-0000-000000000000",
|
||||
devConfigFlag: "nonexistent-dev-config",
|
||||
wantErr: true,
|
||||
@ -113,7 +113,7 @@ func TestRecover(t *testing.T) {
|
||||
return fs
|
||||
},
|
||||
existingState: validState,
|
||||
ipFlag: "192.0.2.1",
|
||||
endpointFlag: "192.0.2.1",
|
||||
diskUUIDFlag: "00000000-0000-0000-0000-000000000000",
|
||||
stateless: true,
|
||||
wantErr: true,
|
||||
@ -125,7 +125,7 @@ func TestRecover(t *testing.T) {
|
||||
return fs
|
||||
},
|
||||
existingState: invalidCSPState,
|
||||
ipFlag: "192.0.2.1",
|
||||
endpointFlag: "192.0.2.1",
|
||||
diskUUIDFlag: "00000000-0000-0000-0000-000000000000",
|
||||
wantErr: true,
|
||||
},
|
||||
@ -137,7 +137,7 @@ func TestRecover(t *testing.T) {
|
||||
},
|
||||
existingState: validState,
|
||||
client: &stubRecoveryClient{connectErr: errors.New("connect failed")},
|
||||
ipFlag: "192.0.2.1",
|
||||
endpointFlag: "192.0.2.1",
|
||||
diskUUIDFlag: "00000000-0000-0000-0000-000000000000",
|
||||
wantErr: true,
|
||||
},
|
||||
@ -149,7 +149,7 @@ func TestRecover(t *testing.T) {
|
||||
},
|
||||
existingState: validState,
|
||||
client: &stubRecoveryClient{pushStateDiskKeyErr: errors.New("pushing key failed")},
|
||||
ipFlag: "192.0.2.1",
|
||||
endpointFlag: "192.0.2.1",
|
||||
diskUUIDFlag: "00000000-0000-0000-0000-000000000000",
|
||||
wantErr: true,
|
||||
},
|
||||
@ -165,8 +165,8 @@ func TestRecover(t *testing.T) {
|
||||
out := &bytes.Buffer{}
|
||||
cmd.SetOut(out)
|
||||
cmd.SetErr(&bytes.Buffer{})
|
||||
if tc.ipFlag != "" {
|
||||
require.NoError(cmd.Flags().Set("ip", tc.ipFlag))
|
||||
if tc.endpointFlag != "" {
|
||||
require.NoError(cmd.Flags().Set("endpoint", tc.endpointFlag))
|
||||
}
|
||||
if tc.diskUUIDFlag != "" {
|
||||
require.NoError(cmd.Flags().Set("disk-uuid", tc.diskUUIDFlag))
|
||||
@ -207,41 +207,41 @@ func TestParseRecoverFlags(t *testing.T) {
|
||||
wantErr: true,
|
||||
},
|
||||
"invalid ip": {
|
||||
args: []string{"--ip", "invalid", "--disk-uuid", "12345678-1234-1234-1234-123456789012"},
|
||||
args: []string{"-e", "192.0.2.1:2:2", "--disk-uuid", "12345678-1234-1234-1234-123456789012"},
|
||||
wantErr: true,
|
||||
},
|
||||
"invalid disk uuid": {
|
||||
args: []string{"--ip", "192.0.2.1", "--disk-uuid", "invalid"},
|
||||
args: []string{"-e", "192.0.2.1:2", "--disk-uuid", "invalid"},
|
||||
wantErr: true,
|
||||
},
|
||||
"invalid master secret path": {
|
||||
args: []string{"--ip", "192.0.2.1", "--disk-uuid", "12345678-1234-1234-1234-123456789012", "--master-secret", "invalid"},
|
||||
args: []string{"-e", "192.0.2.1:2", "--disk-uuid", "12345678-1234-1234-1234-123456789012", "--master-secret", "invalid"},
|
||||
wantErr: true,
|
||||
},
|
||||
"minimal args set": {
|
||||
args: []string{"--ip", "192.0.2.1", "--disk-uuid", "12345678-1234-1234-1234-123456789012"},
|
||||
args: []string{"-e", "192.0.2.1:2", "--disk-uuid", "12345678-1234-1234-1234-123456789012"},
|
||||
wantFlags: recoverFlags{
|
||||
ip: "192.0.2.1",
|
||||
endpoint: "192.0.2.1:2",
|
||||
diskUUID: "12345678-1234-1234-1234-123456789012",
|
||||
masterSecret: []byte("constellation-master-secret-leng"),
|
||||
},
|
||||
},
|
||||
"all args set": {
|
||||
args: []string{
|
||||
"--ip", "192.0.2.1", "--disk-uuid", "12345678-1234-1234-1234-123456789012",
|
||||
"-e", "192.0.2.1:2", "--disk-uuid", "12345678-1234-1234-1234-123456789012",
|
||||
"--master-secret", "constellation-mastersecret.base64", "--dev-config", "dev-config-path",
|
||||
},
|
||||
wantFlags: recoverFlags{
|
||||
ip: "192.0.2.1",
|
||||
endpoint: "192.0.2.1:2",
|
||||
diskUUID: "12345678-1234-1234-1234-123456789012",
|
||||
masterSecret: []byte("constellation-master-secret-leng"),
|
||||
devConfigPath: "dev-config-path",
|
||||
},
|
||||
},
|
||||
"uppercase disk-uuid is converted to lowercase": {
|
||||
args: []string{"--ip", "192.0.2.1", "--disk-uuid", "ABCDEFAB-CDEF-ABCD-ABCD-ABCDEFABCDEF"},
|
||||
args: []string{"-e", "192.0.2.1:2", "--disk-uuid", "ABCDEFAB-CDEF-ABCD-ABCD-ABCDEFABCDEF"},
|
||||
wantFlags: recoverFlags{
|
||||
ip: "192.0.2.1",
|
||||
endpoint: "192.0.2.1:2",
|
||||
diskUUID: "abcdefab-cdef-abcd-abcd-abcdefabcdef",
|
||||
masterSecret: []byte("constellation-master-secret-leng"),
|
||||
},
|
||||
|
@ -8,7 +8,7 @@ import (
|
||||
)
|
||||
|
||||
type recoveryClient interface {
|
||||
Connect(ip, port string, validators []atls.Validator) error
|
||||
Connect(endpoint string, validators []atls.Validator) error
|
||||
PushStateDiskKey(ctx context.Context, stateDiskKey []byte) error
|
||||
io.Closer
|
||||
}
|
||||
|
@ -15,7 +15,7 @@ type stubRecoveryClient struct {
|
||||
pushStateDiskKeyKey []byte
|
||||
}
|
||||
|
||||
func (c *stubRecoveryClient) Connect(_, _ string, _ []atls.Validator) error {
|
||||
func (c *stubRecoveryClient) Connect(_ string, _ []atls.Validator) error {
|
||||
c.conn = true
|
||||
return c.connectErr
|
||||
}
|
||||
|
@ -3,6 +3,9 @@ package cmd
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/edgelesssys/constellation/cli/azure"
|
||||
"github.com/edgelesssys/constellation/cli/cloudprovider"
|
||||
@ -49,3 +52,16 @@ func validInstanceTypeForProvider(insType string, provider cloudprovider.Provide
|
||||
return fmt.Errorf("%s isn't a valid cloud platform", provider)
|
||||
}
|
||||
}
|
||||
|
||||
func validateEndpoint(endpoint string, defaultPort int) (string, error) {
|
||||
_, _, err := net.SplitHostPort(endpoint)
|
||||
if err == nil {
|
||||
return endpoint, nil
|
||||
}
|
||||
|
||||
if strings.Contains(err.Error(), "missing port in address") {
|
||||
return net.JoinHostPort(endpoint, strconv.Itoa(defaultPort)), nil
|
||||
}
|
||||
|
||||
return "", err
|
||||
}
|
||||
|
@ -5,6 +5,7 @@ import (
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestIsCloudProvider(t *testing.T) {
|
||||
@ -36,3 +37,54 @@ func TestIsCloudProvider(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateEndpoint(t *testing.T) {
|
||||
testCases := map[string]struct {
|
||||
endpoint string
|
||||
defaultPort int
|
||||
wantResult string
|
||||
wantErr bool
|
||||
}{
|
||||
"ip and port": {
|
||||
endpoint: "192.0.2.1:2",
|
||||
defaultPort: 3,
|
||||
wantResult: "192.0.2.1:2",
|
||||
},
|
||||
"hostname and port": {
|
||||
endpoint: "foo:2",
|
||||
defaultPort: 3,
|
||||
wantResult: "foo:2",
|
||||
},
|
||||
"ip": {
|
||||
endpoint: "192.0.2.1",
|
||||
defaultPort: 3,
|
||||
wantResult: "192.0.2.1:3",
|
||||
},
|
||||
"hostname": {
|
||||
endpoint: "foo",
|
||||
defaultPort: 3,
|
||||
wantResult: "foo:3",
|
||||
},
|
||||
"invalid endpoint": {
|
||||
endpoint: "foo:2:2",
|
||||
defaultPort: 3,
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for name, tc := range testCases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
require := require.New(t)
|
||||
|
||||
res, err := validateEndpoint(tc.endpoint, tc.defaultPort)
|
||||
if tc.wantErr {
|
||||
assert.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
require.NoError(err)
|
||||
assert.Equal(tc.wantResult, res)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
@ -4,9 +4,6 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/edgelesssys/constellation/cli/cloud/cloudcmd"
|
||||
"github.com/edgelesssys/constellation/cli/cloudprovider"
|
||||
@ -69,7 +66,7 @@ func verify(ctx context.Context, cmd *cobra.Command, provider cloudprovider.Prov
|
||||
cmd.Print(validators.Warnings())
|
||||
}
|
||||
|
||||
if err := protoClient.Connect(flags.nodeHost, flags.nodePort, validators.V()); err != nil {
|
||||
if err := protoClient.Connect(flags.endpoint, validators.V()); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := protoClient.GetState(ctx); err != nil {
|
||||
@ -100,13 +97,9 @@ func parseVerifyFlags(cmd *cobra.Command) (verifyFlags, error) {
|
||||
if err != nil {
|
||||
return verifyFlags{}, err
|
||||
}
|
||||
host, port, err := net.SplitHostPort(endpoint)
|
||||
endpoint, err = validateEndpoint(endpoint, constants.CoordinatorPort)
|
||||
if err != nil {
|
||||
if !strings.Contains(err.Error(), "missing port in address") {
|
||||
return verifyFlags{}, err
|
||||
}
|
||||
host = endpoint
|
||||
port = strconv.Itoa(constants.CoordinatorPort)
|
||||
return verifyFlags{}, err
|
||||
}
|
||||
|
||||
devConfigPath, err := cmd.Flags().GetString("dev-config")
|
||||
@ -115,8 +108,7 @@ func parseVerifyFlags(cmd *cobra.Command) (verifyFlags, error) {
|
||||
}
|
||||
|
||||
return verifyFlags{
|
||||
nodeHost: host,
|
||||
nodePort: port,
|
||||
endpoint: endpoint,
|
||||
devConfigPath: devConfigPath,
|
||||
ownerID: ownerID,
|
||||
clusterID: clusterID,
|
||||
@ -124,8 +116,7 @@ func parseVerifyFlags(cmd *cobra.Command) (verifyFlags, error) {
|
||||
}
|
||||
|
||||
type verifyFlags struct {
|
||||
nodeHost string
|
||||
nodePort string
|
||||
endpoint string
|
||||
ownerID string
|
||||
clusterID string
|
||||
devConfigPath string
|
||||
|
@ -4,7 +4,6 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
|
||||
"github.com/edgelesssys/constellation/coordinator/atls"
|
||||
"github.com/edgelesssys/constellation/coordinator/kms"
|
||||
@ -26,15 +25,13 @@ type Client struct {
|
||||
// The connection must be closed using Close(). If connect is
|
||||
// called on a client that already has a connection, the old
|
||||
// connection is closed.
|
||||
func (c *Client) Connect(ip, port string, validators []atls.Validator) error {
|
||||
addr := net.JoinHostPort(ip, port)
|
||||
|
||||
func (c *Client) Connect(endpoint string, validators []atls.Validator) error {
|
||||
tlsConfig, err := atls.CreateAttestationClientTLSConfig(validators)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
conn, err := grpc.Dial(addr, grpc.WithTransportCredentials(credentials.NewTLS(tlsConfig)))
|
||||
conn, err := grpc.Dial(endpoint, grpc.WithTransportCredentials(credentials.NewTLS(tlsConfig)))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -3,7 +3,6 @@ package proto
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net"
|
||||
|
||||
"github.com/edgelesssys/constellation/coordinator/atls"
|
||||
"github.com/edgelesssys/constellation/state/keyservice/keyproto"
|
||||
@ -22,15 +21,13 @@ type KeyClient struct {
|
||||
// The connection must be closed using Close(). If connect is
|
||||
// called on a client that already has a connection, the old
|
||||
// connection is closed.
|
||||
func (c *KeyClient) Connect(ip, port string, validators []atls.Validator) error {
|
||||
addr := net.JoinHostPort(ip, port)
|
||||
|
||||
func (c *KeyClient) Connect(endpoint string, validators []atls.Validator) error {
|
||||
tlsConfig, err := atls.CreateAttestationClientTLSConfig(validators)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
conn, err := grpc.Dial(addr, grpc.WithTransportCredentials(credentials.NewTLS(tlsConfig)))
|
||||
conn, err := grpc.Dial(endpoint, grpc.WithTransportCredentials(credentials.NewTLS(tlsConfig)))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user