mirror of
https://github.com/edgelesssys/constellation.git
synced 2024-10-01 01:36:09 -04:00
cli: unify verify/recover endpoint flag
This commit is contained in:
parent
c9226de9ab
commit
3318126363
@ -160,7 +160,7 @@ func initialize(ctx context.Context, cmd *cobra.Command, protCl protoClient, ser
|
|||||||
func activate(ctx context.Context, cmd *cobra.Command, client protoClient, input activationInput,
|
func activate(ctx context.Context, cmd *cobra.Command, client protoClient, input activationInput,
|
||||||
config *config.Config, validators []atls.Validator,
|
config *config.Config, validators []atls.Validator,
|
||||||
) (activationResult, error) {
|
) (activationResult, error) {
|
||||||
err := client.Connect(input.coordinatorPubIP, *config.CoordinatorPort, validators)
|
err := client.Connect(net.JoinHostPort(input.coordinatorPubIP, *config.CoordinatorPort), validators)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return activationResult{}, err
|
return activationResult{}, err
|
||||||
}
|
}
|
||||||
|
@ -9,7 +9,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type protoClient interface {
|
type protoClient interface {
|
||||||
Connect(ip, port string, validators []atls.Validator) error
|
Connect(endpoint string, validators []atls.Validator) error
|
||||||
Close() error
|
Close() error
|
||||||
GetState(ctx context.Context) (state.State, error)
|
GetState(ctx context.Context) (state.State, error)
|
||||||
Activate(ctx context.Context, userPublicKey, masterSecret []byte, nodeIPs, coordinatorIPs, autoscalingNodeGroups []string, cloudServiceAccountURI string) (proto.ActivationResponseClient, error)
|
Activate(ctx context.Context, userPublicKey, masterSecret []byte, nodeIPs, coordinatorIPs, autoscalingNodeGroups []string, cloudServiceAccountURI string) (proto.ActivationResponseClient, error)
|
||||||
|
@ -28,7 +28,7 @@ type stubProtoClient struct {
|
|||||||
cloudServiceAccountURI string
|
cloudServiceAccountURI string
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *stubProtoClient) Connect(_, _ string, _ []atls.Validator) error {
|
func (c *stubProtoClient) Connect(_ string, _ []atls.Validator) error {
|
||||||
c.conn = true
|
c.conn = true
|
||||||
return c.connectErr
|
return c.connectErr
|
||||||
}
|
}
|
||||||
@ -103,9 +103,9 @@ type fakeProtoClient struct {
|
|||||||
respClient proto.ActivationResponseClient
|
respClient proto.ActivationResponseClient
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *fakeProtoClient) Connect(ip, port string, validators []atls.Validator) error {
|
func (c *fakeProtoClient) Connect(endpoint string, validators []atls.Validator) error {
|
||||||
if ip == "" || port == "" {
|
if endpoint == "" {
|
||||||
return errors.New("ip or port is empty")
|
return errors.New("endpoint is empty")
|
||||||
}
|
}
|
||||||
if len(validators) == 0 {
|
if len(validators) == 0 {
|
||||||
return errors.New("validators is empty")
|
return errors.New("validators is empty")
|
||||||
|
@ -4,7 +4,6 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"errors"
|
"errors"
|
||||||
"net"
|
|
||||||
"regexp"
|
"regexp"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
@ -31,8 +30,8 @@ func newRecoverCmd() *cobra.Command {
|
|||||||
Args: cobra.ExactArgs(0),
|
Args: cobra.ExactArgs(0),
|
||||||
RunE: runRecover,
|
RunE: runRecover,
|
||||||
}
|
}
|
||||||
cmd.Flags().String("ip", "", "Instance IP address.")
|
cmd.Flags().StringP("endpoint", "e", "", "Endpoint of the instance. Form: HOST[:PORT]")
|
||||||
must(cmd.MarkFlagRequired("ip"))
|
must(cmd.MarkFlagRequired("endpoint"))
|
||||||
cmd.Flags().String("disk-uuid", "", "Disk UUID of the encrypted state disk.")
|
cmd.Flags().String("disk-uuid", "", "Disk UUID of the encrypted state disk.")
|
||||||
must(cmd.MarkFlagRequired("disk-uuid"))
|
must(cmd.MarkFlagRequired("disk-uuid"))
|
||||||
cmd.Flags().String("master-secret", "", "Path to base64 encoded master secret.")
|
cmd.Flags().String("master-secret", "", "Path to base64 encoded master secret.")
|
||||||
@ -68,7 +67,7 @@ func recover(ctx context.Context, cmd *cobra.Command, fileHandler file.Handler,
|
|||||||
}
|
}
|
||||||
cmd.Print(validators.WarningsIncludeInit())
|
cmd.Print(validators.WarningsIncludeInit())
|
||||||
|
|
||||||
if err := recoveryClient.Connect(flags.ip, *config.CoordinatorPort, validators.V()); err != nil {
|
if err := recoveryClient.Connect(flags.endpoint, validators.V()); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -86,12 +85,13 @@ func recover(ctx context.Context, cmd *cobra.Command, fileHandler file.Handler,
|
|||||||
}
|
}
|
||||||
|
|
||||||
func parseRecoverFlags(cmd *cobra.Command, fileHandler file.Handler) (recoverFlags, error) {
|
func parseRecoverFlags(cmd *cobra.Command, fileHandler file.Handler) (recoverFlags, error) {
|
||||||
ip, err := cmd.Flags().GetString("ip")
|
endpoint, err := cmd.Flags().GetString("endpoint")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return recoverFlags{}, err
|
return recoverFlags{}, err
|
||||||
}
|
}
|
||||||
if netIP := net.ParseIP(ip); netIP == nil {
|
endpoint, err = validateEndpoint(endpoint, constants.CoordinatorPort)
|
||||||
return recoverFlags{}, errors.New("flag '--ip' isn't a valid IP address")
|
if err != nil {
|
||||||
|
return recoverFlags{}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
diskUUID, err := cmd.Flags().GetString("disk-uuid")
|
diskUUID, err := cmd.Flags().GetString("disk-uuid")
|
||||||
@ -121,7 +121,7 @@ func parseRecoverFlags(cmd *cobra.Command, fileHandler file.Handler) (recoverFla
|
|||||||
}
|
}
|
||||||
|
|
||||||
return recoverFlags{
|
return recoverFlags{
|
||||||
ip: ip,
|
endpoint: endpoint,
|
||||||
diskUUID: diskUUID,
|
diskUUID: diskUUID,
|
||||||
masterSecret: masterSecret,
|
masterSecret: masterSecret,
|
||||||
devConfigPath: devConfigPath,
|
devConfigPath: devConfigPath,
|
||||||
@ -129,7 +129,7 @@ func parseRecoverFlags(cmd *cobra.Command, fileHandler file.Handler) (recoverFla
|
|||||||
}
|
}
|
||||||
|
|
||||||
type recoverFlags struct {
|
type recoverFlags struct {
|
||||||
ip string
|
endpoint string
|
||||||
diskUUID string
|
diskUUID string
|
||||||
masterSecret []byte
|
masterSecret []byte
|
||||||
devConfigPath string
|
devConfigPath string
|
||||||
|
@ -47,7 +47,7 @@ func TestRecover(t *testing.T) {
|
|||||||
setupFs func(*require.Assertions) afero.Fs
|
setupFs func(*require.Assertions) afero.Fs
|
||||||
existingState state.ConstellationState
|
existingState state.ConstellationState
|
||||||
client *stubRecoveryClient
|
client *stubRecoveryClient
|
||||||
ipFlag string
|
endpointFlag string
|
||||||
diskUUIDFlag string
|
diskUUIDFlag string
|
||||||
masterSecretFlag string
|
masterSecretFlag string
|
||||||
devConfigFlag string
|
devConfigFlag string
|
||||||
@ -63,7 +63,7 @@ func TestRecover(t *testing.T) {
|
|||||||
},
|
},
|
||||||
existingState: validState,
|
existingState: validState,
|
||||||
client: &stubRecoveryClient{},
|
client: &stubRecoveryClient{},
|
||||||
ipFlag: "192.0.2.1",
|
endpointFlag: "192.0.2.1",
|
||||||
diskUUIDFlag: "00000000-0000-0000-0000-000000000000",
|
diskUUIDFlag: "00000000-0000-0000-0000-000000000000",
|
||||||
wantKey: []byte{0x2e, 0x4d, 0x40, 0x3a, 0x90, 0x96, 0x6e, 0xd, 0x42, 0x3, 0x98, 0xd, 0xce, 0xc5, 0x73, 0x26, 0xf4, 0x87, 0xcf, 0x85, 0x73, 0xe1, 0xb7, 0xd6, 0xb2, 0x82, 0x4c, 0xd9, 0xbc, 0xa5, 0x7c, 0x32},
|
wantKey: []byte{0x2e, 0x4d, 0x40, 0x3a, 0x90, 0x96, 0x6e, 0xd, 0x42, 0x3, 0x98, 0xd, 0xce, 0xc5, 0x73, 0x26, 0xf4, 0x87, 0xcf, 0x85, 0x73, 0xe1, 0xb7, 0xd6, 0xb2, 0x82, 0x4c, 0xd9, 0xbc, 0xa5, 0x7c, 0x32},
|
||||||
},
|
},
|
||||||
@ -75,7 +75,7 @@ func TestRecover(t *testing.T) {
|
|||||||
},
|
},
|
||||||
existingState: validState,
|
existingState: validState,
|
||||||
client: &stubRecoveryClient{},
|
client: &stubRecoveryClient{},
|
||||||
ipFlag: "192.0.2.1",
|
endpointFlag: "192.0.2.1",
|
||||||
diskUUIDFlag: "ABCDEFAB-CDEF-ABCD-ABCD-ABCDEFABCDEF",
|
diskUUIDFlag: "ABCDEFAB-CDEF-ABCD-ABCD-ABCDEFABCDEF",
|
||||||
wantKey: []byte{0xa9, 0x4, 0x3a, 0x74, 0x53, 0xeb, 0x23, 0xb2, 0xbc, 0x88, 0xce, 0xa7, 0x4e, 0xa9, 0xda, 0x9f, 0x11, 0x85, 0xc4, 0x2f, 0x1f, 0x25, 0x10, 0xc9, 0xec, 0xfe, 0xa, 0x6c, 0xa2, 0x6f, 0x53, 0x34},
|
wantKey: []byte{0xa9, 0x4, 0x3a, 0x74, 0x53, 0xeb, 0x23, 0xb2, 0xbc, 0x88, 0xce, 0xa7, 0x4e, 0xa9, 0xda, 0x9f, 0x11, 0x85, 0xc4, 0x2f, 0x1f, 0x25, 0x10, 0xc9, 0xec, 0xfe, 0xa, 0x6c, 0xa2, 0x6f, 0x53, 0x34},
|
||||||
},
|
},
|
||||||
@ -87,7 +87,7 @@ func TestRecover(t *testing.T) {
|
|||||||
},
|
},
|
||||||
existingState: validState,
|
existingState: validState,
|
||||||
client: &stubRecoveryClient{},
|
client: &stubRecoveryClient{},
|
||||||
ipFlag: "192.0.2.1",
|
endpointFlag: "192.0.2.1",
|
||||||
diskUUIDFlag: "abcdefab-cdef-abcd-abcd-abcdefabcdef",
|
diskUUIDFlag: "abcdefab-cdef-abcd-abcd-abcdefabcdef",
|
||||||
wantKey: []byte{0xa9, 0x4, 0x3a, 0x74, 0x53, 0xeb, 0x23, 0xb2, 0xbc, 0x88, 0xce, 0xa7, 0x4e, 0xa9, 0xda, 0x9f, 0x11, 0x85, 0xc4, 0x2f, 0x1f, 0x25, 0x10, 0xc9, 0xec, 0xfe, 0xa, 0x6c, 0xa2, 0x6f, 0x53, 0x34},
|
wantKey: []byte{0xa9, 0x4, 0x3a, 0x74, 0x53, 0xeb, 0x23, 0xb2, 0xbc, 0x88, 0xce, 0xa7, 0x4e, 0xa9, 0xda, 0x9f, 0x11, 0x85, 0xc4, 0x2f, 0x1f, 0x25, 0x10, 0xc9, 0xec, 0xfe, 0xa, 0x6c, 0xa2, 0x6f, 0x53, 0x34},
|
||||||
},
|
},
|
||||||
@ -101,7 +101,7 @@ func TestRecover(t *testing.T) {
|
|||||||
require.NoError(afero.WriteFile(fs, "constellation-mastersecret.base64", []byte("Y29uc3RlbGxhdGlvbi1tYXN0ZXItc2VjcmV0LWxlbmc="), 0o777))
|
require.NoError(afero.WriteFile(fs, "constellation-mastersecret.base64", []byte("Y29uc3RlbGxhdGlvbi1tYXN0ZXItc2VjcmV0LWxlbmc="), 0o777))
|
||||||
return fs
|
return fs
|
||||||
},
|
},
|
||||||
ipFlag: "192.0.2.1",
|
endpointFlag: "192.0.2.1",
|
||||||
diskUUIDFlag: "00000000-0000-0000-0000-000000000000",
|
diskUUIDFlag: "00000000-0000-0000-0000-000000000000",
|
||||||
devConfigFlag: "nonexistent-dev-config",
|
devConfigFlag: "nonexistent-dev-config",
|
||||||
wantErr: true,
|
wantErr: true,
|
||||||
@ -113,7 +113,7 @@ func TestRecover(t *testing.T) {
|
|||||||
return fs
|
return fs
|
||||||
},
|
},
|
||||||
existingState: validState,
|
existingState: validState,
|
||||||
ipFlag: "192.0.2.1",
|
endpointFlag: "192.0.2.1",
|
||||||
diskUUIDFlag: "00000000-0000-0000-0000-000000000000",
|
diskUUIDFlag: "00000000-0000-0000-0000-000000000000",
|
||||||
stateless: true,
|
stateless: true,
|
||||||
wantErr: true,
|
wantErr: true,
|
||||||
@ -125,7 +125,7 @@ func TestRecover(t *testing.T) {
|
|||||||
return fs
|
return fs
|
||||||
},
|
},
|
||||||
existingState: invalidCSPState,
|
existingState: invalidCSPState,
|
||||||
ipFlag: "192.0.2.1",
|
endpointFlag: "192.0.2.1",
|
||||||
diskUUIDFlag: "00000000-0000-0000-0000-000000000000",
|
diskUUIDFlag: "00000000-0000-0000-0000-000000000000",
|
||||||
wantErr: true,
|
wantErr: true,
|
||||||
},
|
},
|
||||||
@ -137,7 +137,7 @@ func TestRecover(t *testing.T) {
|
|||||||
},
|
},
|
||||||
existingState: validState,
|
existingState: validState,
|
||||||
client: &stubRecoveryClient{connectErr: errors.New("connect failed")},
|
client: &stubRecoveryClient{connectErr: errors.New("connect failed")},
|
||||||
ipFlag: "192.0.2.1",
|
endpointFlag: "192.0.2.1",
|
||||||
diskUUIDFlag: "00000000-0000-0000-0000-000000000000",
|
diskUUIDFlag: "00000000-0000-0000-0000-000000000000",
|
||||||
wantErr: true,
|
wantErr: true,
|
||||||
},
|
},
|
||||||
@ -149,7 +149,7 @@ func TestRecover(t *testing.T) {
|
|||||||
},
|
},
|
||||||
existingState: validState,
|
existingState: validState,
|
||||||
client: &stubRecoveryClient{pushStateDiskKeyErr: errors.New("pushing key failed")},
|
client: &stubRecoveryClient{pushStateDiskKeyErr: errors.New("pushing key failed")},
|
||||||
ipFlag: "192.0.2.1",
|
endpointFlag: "192.0.2.1",
|
||||||
diskUUIDFlag: "00000000-0000-0000-0000-000000000000",
|
diskUUIDFlag: "00000000-0000-0000-0000-000000000000",
|
||||||
wantErr: true,
|
wantErr: true,
|
||||||
},
|
},
|
||||||
@ -165,8 +165,8 @@ func TestRecover(t *testing.T) {
|
|||||||
out := &bytes.Buffer{}
|
out := &bytes.Buffer{}
|
||||||
cmd.SetOut(out)
|
cmd.SetOut(out)
|
||||||
cmd.SetErr(&bytes.Buffer{})
|
cmd.SetErr(&bytes.Buffer{})
|
||||||
if tc.ipFlag != "" {
|
if tc.endpointFlag != "" {
|
||||||
require.NoError(cmd.Flags().Set("ip", tc.ipFlag))
|
require.NoError(cmd.Flags().Set("endpoint", tc.endpointFlag))
|
||||||
}
|
}
|
||||||
if tc.diskUUIDFlag != "" {
|
if tc.diskUUIDFlag != "" {
|
||||||
require.NoError(cmd.Flags().Set("disk-uuid", tc.diskUUIDFlag))
|
require.NoError(cmd.Flags().Set("disk-uuid", tc.diskUUIDFlag))
|
||||||
@ -207,41 +207,41 @@ func TestParseRecoverFlags(t *testing.T) {
|
|||||||
wantErr: true,
|
wantErr: true,
|
||||||
},
|
},
|
||||||
"invalid ip": {
|
"invalid ip": {
|
||||||
args: []string{"--ip", "invalid", "--disk-uuid", "12345678-1234-1234-1234-123456789012"},
|
args: []string{"-e", "192.0.2.1:2:2", "--disk-uuid", "12345678-1234-1234-1234-123456789012"},
|
||||||
wantErr: true,
|
wantErr: true,
|
||||||
},
|
},
|
||||||
"invalid disk uuid": {
|
"invalid disk uuid": {
|
||||||
args: []string{"--ip", "192.0.2.1", "--disk-uuid", "invalid"},
|
args: []string{"-e", "192.0.2.1:2", "--disk-uuid", "invalid"},
|
||||||
wantErr: true,
|
wantErr: true,
|
||||||
},
|
},
|
||||||
"invalid master secret path": {
|
"invalid master secret path": {
|
||||||
args: []string{"--ip", "192.0.2.1", "--disk-uuid", "12345678-1234-1234-1234-123456789012", "--master-secret", "invalid"},
|
args: []string{"-e", "192.0.2.1:2", "--disk-uuid", "12345678-1234-1234-1234-123456789012", "--master-secret", "invalid"},
|
||||||
wantErr: true,
|
wantErr: true,
|
||||||
},
|
},
|
||||||
"minimal args set": {
|
"minimal args set": {
|
||||||
args: []string{"--ip", "192.0.2.1", "--disk-uuid", "12345678-1234-1234-1234-123456789012"},
|
args: []string{"-e", "192.0.2.1:2", "--disk-uuid", "12345678-1234-1234-1234-123456789012"},
|
||||||
wantFlags: recoverFlags{
|
wantFlags: recoverFlags{
|
||||||
ip: "192.0.2.1",
|
endpoint: "192.0.2.1:2",
|
||||||
diskUUID: "12345678-1234-1234-1234-123456789012",
|
diskUUID: "12345678-1234-1234-1234-123456789012",
|
||||||
masterSecret: []byte("constellation-master-secret-leng"),
|
masterSecret: []byte("constellation-master-secret-leng"),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
"all args set": {
|
"all args set": {
|
||||||
args: []string{
|
args: []string{
|
||||||
"--ip", "192.0.2.1", "--disk-uuid", "12345678-1234-1234-1234-123456789012",
|
"-e", "192.0.2.1:2", "--disk-uuid", "12345678-1234-1234-1234-123456789012",
|
||||||
"--master-secret", "constellation-mastersecret.base64", "--dev-config", "dev-config-path",
|
"--master-secret", "constellation-mastersecret.base64", "--dev-config", "dev-config-path",
|
||||||
},
|
},
|
||||||
wantFlags: recoverFlags{
|
wantFlags: recoverFlags{
|
||||||
ip: "192.0.2.1",
|
endpoint: "192.0.2.1:2",
|
||||||
diskUUID: "12345678-1234-1234-1234-123456789012",
|
diskUUID: "12345678-1234-1234-1234-123456789012",
|
||||||
masterSecret: []byte("constellation-master-secret-leng"),
|
masterSecret: []byte("constellation-master-secret-leng"),
|
||||||
devConfigPath: "dev-config-path",
|
devConfigPath: "dev-config-path",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
"uppercase disk-uuid is converted to lowercase": {
|
"uppercase disk-uuid is converted to lowercase": {
|
||||||
args: []string{"--ip", "192.0.2.1", "--disk-uuid", "ABCDEFAB-CDEF-ABCD-ABCD-ABCDEFABCDEF"},
|
args: []string{"-e", "192.0.2.1:2", "--disk-uuid", "ABCDEFAB-CDEF-ABCD-ABCD-ABCDEFABCDEF"},
|
||||||
wantFlags: recoverFlags{
|
wantFlags: recoverFlags{
|
||||||
ip: "192.0.2.1",
|
endpoint: "192.0.2.1:2",
|
||||||
diskUUID: "abcdefab-cdef-abcd-abcd-abcdefabcdef",
|
diskUUID: "abcdefab-cdef-abcd-abcd-abcdefabcdef",
|
||||||
masterSecret: []byte("constellation-master-secret-leng"),
|
masterSecret: []byte("constellation-master-secret-leng"),
|
||||||
},
|
},
|
||||||
|
@ -8,7 +8,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type recoveryClient interface {
|
type recoveryClient interface {
|
||||||
Connect(ip, port string, validators []atls.Validator) error
|
Connect(endpoint string, validators []atls.Validator) error
|
||||||
PushStateDiskKey(ctx context.Context, stateDiskKey []byte) error
|
PushStateDiskKey(ctx context.Context, stateDiskKey []byte) error
|
||||||
io.Closer
|
io.Closer
|
||||||
}
|
}
|
||||||
|
@ -15,7 +15,7 @@ type stubRecoveryClient struct {
|
|||||||
pushStateDiskKeyKey []byte
|
pushStateDiskKeyKey []byte
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *stubRecoveryClient) Connect(_, _ string, _ []atls.Validator) error {
|
func (c *stubRecoveryClient) Connect(_ string, _ []atls.Validator) error {
|
||||||
c.conn = true
|
c.conn = true
|
||||||
return c.connectErr
|
return c.connectErr
|
||||||
}
|
}
|
||||||
|
@ -3,6 +3,9 @@ package cmd
|
|||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
|
||||||
"github.com/edgelesssys/constellation/cli/azure"
|
"github.com/edgelesssys/constellation/cli/azure"
|
||||||
"github.com/edgelesssys/constellation/cli/cloudprovider"
|
"github.com/edgelesssys/constellation/cli/cloudprovider"
|
||||||
@ -49,3 +52,16 @@ func validInstanceTypeForProvider(insType string, provider cloudprovider.Provide
|
|||||||
return fmt.Errorf("%s isn't a valid cloud platform", provider)
|
return fmt.Errorf("%s isn't a valid cloud platform", provider)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func validateEndpoint(endpoint string, defaultPort int) (string, error) {
|
||||||
|
_, _, err := net.SplitHostPort(endpoint)
|
||||||
|
if err == nil {
|
||||||
|
return endpoint, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if strings.Contains(err.Error(), "missing port in address") {
|
||||||
|
return net.JoinHostPort(endpoint, strconv.Itoa(defaultPort)), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
@ -5,6 +5,7 @@ import (
|
|||||||
|
|
||||||
"github.com/spf13/cobra"
|
"github.com/spf13/cobra"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestIsCloudProvider(t *testing.T) {
|
func TestIsCloudProvider(t *testing.T) {
|
||||||
@ -36,3 +37,54 @@ func TestIsCloudProvider(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestValidateEndpoint(t *testing.T) {
|
||||||
|
testCases := map[string]struct {
|
||||||
|
endpoint string
|
||||||
|
defaultPort int
|
||||||
|
wantResult string
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
"ip and port": {
|
||||||
|
endpoint: "192.0.2.1:2",
|
||||||
|
defaultPort: 3,
|
||||||
|
wantResult: "192.0.2.1:2",
|
||||||
|
},
|
||||||
|
"hostname and port": {
|
||||||
|
endpoint: "foo:2",
|
||||||
|
defaultPort: 3,
|
||||||
|
wantResult: "foo:2",
|
||||||
|
},
|
||||||
|
"ip": {
|
||||||
|
endpoint: "192.0.2.1",
|
||||||
|
defaultPort: 3,
|
||||||
|
wantResult: "192.0.2.1:3",
|
||||||
|
},
|
||||||
|
"hostname": {
|
||||||
|
endpoint: "foo",
|
||||||
|
defaultPort: 3,
|
||||||
|
wantResult: "foo:3",
|
||||||
|
},
|
||||||
|
"invalid endpoint": {
|
||||||
|
endpoint: "foo:2:2",
|
||||||
|
defaultPort: 3,
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for name, tc := range testCases {
|
||||||
|
t.Run(name, func(t *testing.T) {
|
||||||
|
assert := assert.New(t)
|
||||||
|
require := require.New(t)
|
||||||
|
|
||||||
|
res, err := validateEndpoint(tc.endpoint, tc.defaultPort)
|
||||||
|
if tc.wantErr {
|
||||||
|
assert.Error(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
require.NoError(err)
|
||||||
|
assert.Equal(tc.wantResult, res)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -4,9 +4,6 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
|
||||||
"strconv"
|
|
||||||
"strings"
|
|
||||||
|
|
||||||
"github.com/edgelesssys/constellation/cli/cloud/cloudcmd"
|
"github.com/edgelesssys/constellation/cli/cloud/cloudcmd"
|
||||||
"github.com/edgelesssys/constellation/cli/cloudprovider"
|
"github.com/edgelesssys/constellation/cli/cloudprovider"
|
||||||
@ -69,7 +66,7 @@ func verify(ctx context.Context, cmd *cobra.Command, provider cloudprovider.Prov
|
|||||||
cmd.Print(validators.Warnings())
|
cmd.Print(validators.Warnings())
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := protoClient.Connect(flags.nodeHost, flags.nodePort, validators.V()); err != nil {
|
if err := protoClient.Connect(flags.endpoint, validators.V()); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if _, err := protoClient.GetState(ctx); err != nil {
|
if _, err := protoClient.GetState(ctx); err != nil {
|
||||||
@ -100,13 +97,9 @@ func parseVerifyFlags(cmd *cobra.Command) (verifyFlags, error) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return verifyFlags{}, err
|
return verifyFlags{}, err
|
||||||
}
|
}
|
||||||
host, port, err := net.SplitHostPort(endpoint)
|
endpoint, err = validateEndpoint(endpoint, constants.CoordinatorPort)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if !strings.Contains(err.Error(), "missing port in address") {
|
return verifyFlags{}, err
|
||||||
return verifyFlags{}, err
|
|
||||||
}
|
|
||||||
host = endpoint
|
|
||||||
port = strconv.Itoa(constants.CoordinatorPort)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
devConfigPath, err := cmd.Flags().GetString("dev-config")
|
devConfigPath, err := cmd.Flags().GetString("dev-config")
|
||||||
@ -115,8 +108,7 @@ func parseVerifyFlags(cmd *cobra.Command) (verifyFlags, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
return verifyFlags{
|
return verifyFlags{
|
||||||
nodeHost: host,
|
endpoint: endpoint,
|
||||||
nodePort: port,
|
|
||||||
devConfigPath: devConfigPath,
|
devConfigPath: devConfigPath,
|
||||||
ownerID: ownerID,
|
ownerID: ownerID,
|
||||||
clusterID: clusterID,
|
clusterID: clusterID,
|
||||||
@ -124,8 +116,7 @@ func parseVerifyFlags(cmd *cobra.Command) (verifyFlags, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type verifyFlags struct {
|
type verifyFlags struct {
|
||||||
nodeHost string
|
endpoint string
|
||||||
nodePort string
|
|
||||||
ownerID string
|
ownerID string
|
||||||
clusterID string
|
clusterID string
|
||||||
devConfigPath string
|
devConfigPath string
|
||||||
|
@ -4,7 +4,6 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"io"
|
"io"
|
||||||
"net"
|
|
||||||
|
|
||||||
"github.com/edgelesssys/constellation/coordinator/atls"
|
"github.com/edgelesssys/constellation/coordinator/atls"
|
||||||
"github.com/edgelesssys/constellation/coordinator/kms"
|
"github.com/edgelesssys/constellation/coordinator/kms"
|
||||||
@ -26,15 +25,13 @@ type Client struct {
|
|||||||
// The connection must be closed using Close(). If connect is
|
// The connection must be closed using Close(). If connect is
|
||||||
// called on a client that already has a connection, the old
|
// called on a client that already has a connection, the old
|
||||||
// connection is closed.
|
// connection is closed.
|
||||||
func (c *Client) Connect(ip, port string, validators []atls.Validator) error {
|
func (c *Client) Connect(endpoint string, validators []atls.Validator) error {
|
||||||
addr := net.JoinHostPort(ip, port)
|
|
||||||
|
|
||||||
tlsConfig, err := atls.CreateAttestationClientTLSConfig(validators)
|
tlsConfig, err := atls.CreateAttestationClientTLSConfig(validators)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
conn, err := grpc.Dial(addr, grpc.WithTransportCredentials(credentials.NewTLS(tlsConfig)))
|
conn, err := grpc.Dial(endpoint, grpc.WithTransportCredentials(credentials.NewTLS(tlsConfig)))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -3,7 +3,6 @@ package proto
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"net"
|
|
||||||
|
|
||||||
"github.com/edgelesssys/constellation/coordinator/atls"
|
"github.com/edgelesssys/constellation/coordinator/atls"
|
||||||
"github.com/edgelesssys/constellation/state/keyservice/keyproto"
|
"github.com/edgelesssys/constellation/state/keyservice/keyproto"
|
||||||
@ -22,15 +21,13 @@ type KeyClient struct {
|
|||||||
// The connection must be closed using Close(). If connect is
|
// The connection must be closed using Close(). If connect is
|
||||||
// called on a client that already has a connection, the old
|
// called on a client that already has a connection, the old
|
||||||
// connection is closed.
|
// connection is closed.
|
||||||
func (c *KeyClient) Connect(ip, port string, validators []atls.Validator) error {
|
func (c *KeyClient) Connect(endpoint string, validators []atls.Validator) error {
|
||||||
addr := net.JoinHostPort(ip, port)
|
|
||||||
|
|
||||||
tlsConfig, err := atls.CreateAttestationClientTLSConfig(validators)
|
tlsConfig, err := atls.CreateAttestationClientTLSConfig(validators)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
conn, err := grpc.Dial(addr, grpc.WithTransportCredentials(credentials.NewTLS(tlsConfig)))
|
conn, err := grpc.Dial(endpoint, grpc.WithTransportCredentials(credentials.NewTLS(tlsConfig)))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user