cli: unify verify/recover endpoint flag

This commit is contained in:
Thomas Tendyck 2022-05-06 13:56:02 +02:00 committed by Thomas Tendyck
parent c9226de9ab
commit 3318126363
12 changed files with 114 additions and 61 deletions

View File

@ -160,7 +160,7 @@ func initialize(ctx context.Context, cmd *cobra.Command, protCl protoClient, ser
func activate(ctx context.Context, cmd *cobra.Command, client protoClient, input activationInput, func activate(ctx context.Context, cmd *cobra.Command, client protoClient, input activationInput,
config *config.Config, validators []atls.Validator, config *config.Config, validators []atls.Validator,
) (activationResult, error) { ) (activationResult, error) {
err := client.Connect(input.coordinatorPubIP, *config.CoordinatorPort, validators) err := client.Connect(net.JoinHostPort(input.coordinatorPubIP, *config.CoordinatorPort), validators)
if err != nil { if err != nil {
return activationResult{}, err return activationResult{}, err
} }

View File

@ -9,7 +9,7 @@ import (
) )
type protoClient interface { type protoClient interface {
Connect(ip, port string, validators []atls.Validator) error Connect(endpoint string, validators []atls.Validator) error
Close() error Close() error
GetState(ctx context.Context) (state.State, error) GetState(ctx context.Context) (state.State, error)
Activate(ctx context.Context, userPublicKey, masterSecret []byte, nodeIPs, coordinatorIPs, autoscalingNodeGroups []string, cloudServiceAccountURI string) (proto.ActivationResponseClient, error) Activate(ctx context.Context, userPublicKey, masterSecret []byte, nodeIPs, coordinatorIPs, autoscalingNodeGroups []string, cloudServiceAccountURI string) (proto.ActivationResponseClient, error)

View File

@ -28,7 +28,7 @@ type stubProtoClient struct {
cloudServiceAccountURI string cloudServiceAccountURI string
} }
func (c *stubProtoClient) Connect(_, _ string, _ []atls.Validator) error { func (c *stubProtoClient) Connect(_ string, _ []atls.Validator) error {
c.conn = true c.conn = true
return c.connectErr return c.connectErr
} }
@ -103,9 +103,9 @@ type fakeProtoClient struct {
respClient proto.ActivationResponseClient respClient proto.ActivationResponseClient
} }
func (c *fakeProtoClient) Connect(ip, port string, validators []atls.Validator) error { func (c *fakeProtoClient) Connect(endpoint string, validators []atls.Validator) error {
if ip == "" || port == "" { if endpoint == "" {
return errors.New("ip or port is empty") return errors.New("endpoint is empty")
} }
if len(validators) == 0 { if len(validators) == 0 {
return errors.New("validators is empty") return errors.New("validators is empty")

View File

@ -4,7 +4,6 @@ import (
"context" "context"
"encoding/base64" "encoding/base64"
"errors" "errors"
"net"
"regexp" "regexp"
"strings" "strings"
@ -31,8 +30,8 @@ func newRecoverCmd() *cobra.Command {
Args: cobra.ExactArgs(0), Args: cobra.ExactArgs(0),
RunE: runRecover, RunE: runRecover,
} }
cmd.Flags().String("ip", "", "Instance IP address.") cmd.Flags().StringP("endpoint", "e", "", "Endpoint of the instance. Form: HOST[:PORT]")
must(cmd.MarkFlagRequired("ip")) must(cmd.MarkFlagRequired("endpoint"))
cmd.Flags().String("disk-uuid", "", "Disk UUID of the encrypted state disk.") cmd.Flags().String("disk-uuid", "", "Disk UUID of the encrypted state disk.")
must(cmd.MarkFlagRequired("disk-uuid")) must(cmd.MarkFlagRequired("disk-uuid"))
cmd.Flags().String("master-secret", "", "Path to base64 encoded master secret.") cmd.Flags().String("master-secret", "", "Path to base64 encoded master secret.")
@ -68,7 +67,7 @@ func recover(ctx context.Context, cmd *cobra.Command, fileHandler file.Handler,
} }
cmd.Print(validators.WarningsIncludeInit()) cmd.Print(validators.WarningsIncludeInit())
if err := recoveryClient.Connect(flags.ip, *config.CoordinatorPort, validators.V()); err != nil { if err := recoveryClient.Connect(flags.endpoint, validators.V()); err != nil {
return err return err
} }
@ -86,12 +85,13 @@ func recover(ctx context.Context, cmd *cobra.Command, fileHandler file.Handler,
} }
func parseRecoverFlags(cmd *cobra.Command, fileHandler file.Handler) (recoverFlags, error) { func parseRecoverFlags(cmd *cobra.Command, fileHandler file.Handler) (recoverFlags, error) {
ip, err := cmd.Flags().GetString("ip") endpoint, err := cmd.Flags().GetString("endpoint")
if err != nil { if err != nil {
return recoverFlags{}, err return recoverFlags{}, err
} }
if netIP := net.ParseIP(ip); netIP == nil { endpoint, err = validateEndpoint(endpoint, constants.CoordinatorPort)
return recoverFlags{}, errors.New("flag '--ip' isn't a valid IP address") if err != nil {
return recoverFlags{}, err
} }
diskUUID, err := cmd.Flags().GetString("disk-uuid") diskUUID, err := cmd.Flags().GetString("disk-uuid")
@ -121,7 +121,7 @@ func parseRecoverFlags(cmd *cobra.Command, fileHandler file.Handler) (recoverFla
} }
return recoverFlags{ return recoverFlags{
ip: ip, endpoint: endpoint,
diskUUID: diskUUID, diskUUID: diskUUID,
masterSecret: masterSecret, masterSecret: masterSecret,
devConfigPath: devConfigPath, devConfigPath: devConfigPath,
@ -129,7 +129,7 @@ func parseRecoverFlags(cmd *cobra.Command, fileHandler file.Handler) (recoverFla
} }
type recoverFlags struct { type recoverFlags struct {
ip string endpoint string
diskUUID string diskUUID string
masterSecret []byte masterSecret []byte
devConfigPath string devConfigPath string

View File

@ -47,7 +47,7 @@ func TestRecover(t *testing.T) {
setupFs func(*require.Assertions) afero.Fs setupFs func(*require.Assertions) afero.Fs
existingState state.ConstellationState existingState state.ConstellationState
client *stubRecoveryClient client *stubRecoveryClient
ipFlag string endpointFlag string
diskUUIDFlag string diskUUIDFlag string
masterSecretFlag string masterSecretFlag string
devConfigFlag string devConfigFlag string
@ -63,7 +63,7 @@ func TestRecover(t *testing.T) {
}, },
existingState: validState, existingState: validState,
client: &stubRecoveryClient{}, client: &stubRecoveryClient{},
ipFlag: "192.0.2.1", endpointFlag: "192.0.2.1",
diskUUIDFlag: "00000000-0000-0000-0000-000000000000", diskUUIDFlag: "00000000-0000-0000-0000-000000000000",
wantKey: []byte{0x2e, 0x4d, 0x40, 0x3a, 0x90, 0x96, 0x6e, 0xd, 0x42, 0x3, 0x98, 0xd, 0xce, 0xc5, 0x73, 0x26, 0xf4, 0x87, 0xcf, 0x85, 0x73, 0xe1, 0xb7, 0xd6, 0xb2, 0x82, 0x4c, 0xd9, 0xbc, 0xa5, 0x7c, 0x32}, wantKey: []byte{0x2e, 0x4d, 0x40, 0x3a, 0x90, 0x96, 0x6e, 0xd, 0x42, 0x3, 0x98, 0xd, 0xce, 0xc5, 0x73, 0x26, 0xf4, 0x87, 0xcf, 0x85, 0x73, 0xe1, 0xb7, 0xd6, 0xb2, 0x82, 0x4c, 0xd9, 0xbc, 0xa5, 0x7c, 0x32},
}, },
@ -75,7 +75,7 @@ func TestRecover(t *testing.T) {
}, },
existingState: validState, existingState: validState,
client: &stubRecoveryClient{}, client: &stubRecoveryClient{},
ipFlag: "192.0.2.1", endpointFlag: "192.0.2.1",
diskUUIDFlag: "ABCDEFAB-CDEF-ABCD-ABCD-ABCDEFABCDEF", diskUUIDFlag: "ABCDEFAB-CDEF-ABCD-ABCD-ABCDEFABCDEF",
wantKey: []byte{0xa9, 0x4, 0x3a, 0x74, 0x53, 0xeb, 0x23, 0xb2, 0xbc, 0x88, 0xce, 0xa7, 0x4e, 0xa9, 0xda, 0x9f, 0x11, 0x85, 0xc4, 0x2f, 0x1f, 0x25, 0x10, 0xc9, 0xec, 0xfe, 0xa, 0x6c, 0xa2, 0x6f, 0x53, 0x34}, wantKey: []byte{0xa9, 0x4, 0x3a, 0x74, 0x53, 0xeb, 0x23, 0xb2, 0xbc, 0x88, 0xce, 0xa7, 0x4e, 0xa9, 0xda, 0x9f, 0x11, 0x85, 0xc4, 0x2f, 0x1f, 0x25, 0x10, 0xc9, 0xec, 0xfe, 0xa, 0x6c, 0xa2, 0x6f, 0x53, 0x34},
}, },
@ -87,7 +87,7 @@ func TestRecover(t *testing.T) {
}, },
existingState: validState, existingState: validState,
client: &stubRecoveryClient{}, client: &stubRecoveryClient{},
ipFlag: "192.0.2.1", endpointFlag: "192.0.2.1",
diskUUIDFlag: "abcdefab-cdef-abcd-abcd-abcdefabcdef", diskUUIDFlag: "abcdefab-cdef-abcd-abcd-abcdefabcdef",
wantKey: []byte{0xa9, 0x4, 0x3a, 0x74, 0x53, 0xeb, 0x23, 0xb2, 0xbc, 0x88, 0xce, 0xa7, 0x4e, 0xa9, 0xda, 0x9f, 0x11, 0x85, 0xc4, 0x2f, 0x1f, 0x25, 0x10, 0xc9, 0xec, 0xfe, 0xa, 0x6c, 0xa2, 0x6f, 0x53, 0x34}, wantKey: []byte{0xa9, 0x4, 0x3a, 0x74, 0x53, 0xeb, 0x23, 0xb2, 0xbc, 0x88, 0xce, 0xa7, 0x4e, 0xa9, 0xda, 0x9f, 0x11, 0x85, 0xc4, 0x2f, 0x1f, 0x25, 0x10, 0xc9, 0xec, 0xfe, 0xa, 0x6c, 0xa2, 0x6f, 0x53, 0x34},
}, },
@ -101,7 +101,7 @@ func TestRecover(t *testing.T) {
require.NoError(afero.WriteFile(fs, "constellation-mastersecret.base64", []byte("Y29uc3RlbGxhdGlvbi1tYXN0ZXItc2VjcmV0LWxlbmc="), 0o777)) require.NoError(afero.WriteFile(fs, "constellation-mastersecret.base64", []byte("Y29uc3RlbGxhdGlvbi1tYXN0ZXItc2VjcmV0LWxlbmc="), 0o777))
return fs return fs
}, },
ipFlag: "192.0.2.1", endpointFlag: "192.0.2.1",
diskUUIDFlag: "00000000-0000-0000-0000-000000000000", diskUUIDFlag: "00000000-0000-0000-0000-000000000000",
devConfigFlag: "nonexistent-dev-config", devConfigFlag: "nonexistent-dev-config",
wantErr: true, wantErr: true,
@ -113,7 +113,7 @@ func TestRecover(t *testing.T) {
return fs return fs
}, },
existingState: validState, existingState: validState,
ipFlag: "192.0.2.1", endpointFlag: "192.0.2.1",
diskUUIDFlag: "00000000-0000-0000-0000-000000000000", diskUUIDFlag: "00000000-0000-0000-0000-000000000000",
stateless: true, stateless: true,
wantErr: true, wantErr: true,
@ -125,7 +125,7 @@ func TestRecover(t *testing.T) {
return fs return fs
}, },
existingState: invalidCSPState, existingState: invalidCSPState,
ipFlag: "192.0.2.1", endpointFlag: "192.0.2.1",
diskUUIDFlag: "00000000-0000-0000-0000-000000000000", diskUUIDFlag: "00000000-0000-0000-0000-000000000000",
wantErr: true, wantErr: true,
}, },
@ -137,7 +137,7 @@ func TestRecover(t *testing.T) {
}, },
existingState: validState, existingState: validState,
client: &stubRecoveryClient{connectErr: errors.New("connect failed")}, client: &stubRecoveryClient{connectErr: errors.New("connect failed")},
ipFlag: "192.0.2.1", endpointFlag: "192.0.2.1",
diskUUIDFlag: "00000000-0000-0000-0000-000000000000", diskUUIDFlag: "00000000-0000-0000-0000-000000000000",
wantErr: true, wantErr: true,
}, },
@ -149,7 +149,7 @@ func TestRecover(t *testing.T) {
}, },
existingState: validState, existingState: validState,
client: &stubRecoveryClient{pushStateDiskKeyErr: errors.New("pushing key failed")}, client: &stubRecoveryClient{pushStateDiskKeyErr: errors.New("pushing key failed")},
ipFlag: "192.0.2.1", endpointFlag: "192.0.2.1",
diskUUIDFlag: "00000000-0000-0000-0000-000000000000", diskUUIDFlag: "00000000-0000-0000-0000-000000000000",
wantErr: true, wantErr: true,
}, },
@ -165,8 +165,8 @@ func TestRecover(t *testing.T) {
out := &bytes.Buffer{} out := &bytes.Buffer{}
cmd.SetOut(out) cmd.SetOut(out)
cmd.SetErr(&bytes.Buffer{}) cmd.SetErr(&bytes.Buffer{})
if tc.ipFlag != "" { if tc.endpointFlag != "" {
require.NoError(cmd.Flags().Set("ip", tc.ipFlag)) require.NoError(cmd.Flags().Set("endpoint", tc.endpointFlag))
} }
if tc.diskUUIDFlag != "" { if tc.diskUUIDFlag != "" {
require.NoError(cmd.Flags().Set("disk-uuid", tc.diskUUIDFlag)) require.NoError(cmd.Flags().Set("disk-uuid", tc.diskUUIDFlag))
@ -207,41 +207,41 @@ func TestParseRecoverFlags(t *testing.T) {
wantErr: true, wantErr: true,
}, },
"invalid ip": { "invalid ip": {
args: []string{"--ip", "invalid", "--disk-uuid", "12345678-1234-1234-1234-123456789012"}, args: []string{"-e", "192.0.2.1:2:2", "--disk-uuid", "12345678-1234-1234-1234-123456789012"},
wantErr: true, wantErr: true,
}, },
"invalid disk uuid": { "invalid disk uuid": {
args: []string{"--ip", "192.0.2.1", "--disk-uuid", "invalid"}, args: []string{"-e", "192.0.2.1:2", "--disk-uuid", "invalid"},
wantErr: true, wantErr: true,
}, },
"invalid master secret path": { "invalid master secret path": {
args: []string{"--ip", "192.0.2.1", "--disk-uuid", "12345678-1234-1234-1234-123456789012", "--master-secret", "invalid"}, args: []string{"-e", "192.0.2.1:2", "--disk-uuid", "12345678-1234-1234-1234-123456789012", "--master-secret", "invalid"},
wantErr: true, wantErr: true,
}, },
"minimal args set": { "minimal args set": {
args: []string{"--ip", "192.0.2.1", "--disk-uuid", "12345678-1234-1234-1234-123456789012"}, args: []string{"-e", "192.0.2.1:2", "--disk-uuid", "12345678-1234-1234-1234-123456789012"},
wantFlags: recoverFlags{ wantFlags: recoverFlags{
ip: "192.0.2.1", endpoint: "192.0.2.1:2",
diskUUID: "12345678-1234-1234-1234-123456789012", diskUUID: "12345678-1234-1234-1234-123456789012",
masterSecret: []byte("constellation-master-secret-leng"), masterSecret: []byte("constellation-master-secret-leng"),
}, },
}, },
"all args set": { "all args set": {
args: []string{ args: []string{
"--ip", "192.0.2.1", "--disk-uuid", "12345678-1234-1234-1234-123456789012", "-e", "192.0.2.1:2", "--disk-uuid", "12345678-1234-1234-1234-123456789012",
"--master-secret", "constellation-mastersecret.base64", "--dev-config", "dev-config-path", "--master-secret", "constellation-mastersecret.base64", "--dev-config", "dev-config-path",
}, },
wantFlags: recoverFlags{ wantFlags: recoverFlags{
ip: "192.0.2.1", endpoint: "192.0.2.1:2",
diskUUID: "12345678-1234-1234-1234-123456789012", diskUUID: "12345678-1234-1234-1234-123456789012",
masterSecret: []byte("constellation-master-secret-leng"), masterSecret: []byte("constellation-master-secret-leng"),
devConfigPath: "dev-config-path", devConfigPath: "dev-config-path",
}, },
}, },
"uppercase disk-uuid is converted to lowercase": { "uppercase disk-uuid is converted to lowercase": {
args: []string{"--ip", "192.0.2.1", "--disk-uuid", "ABCDEFAB-CDEF-ABCD-ABCD-ABCDEFABCDEF"}, args: []string{"-e", "192.0.2.1:2", "--disk-uuid", "ABCDEFAB-CDEF-ABCD-ABCD-ABCDEFABCDEF"},
wantFlags: recoverFlags{ wantFlags: recoverFlags{
ip: "192.0.2.1", endpoint: "192.0.2.1:2",
diskUUID: "abcdefab-cdef-abcd-abcd-abcdefabcdef", diskUUID: "abcdefab-cdef-abcd-abcd-abcdefabcdef",
masterSecret: []byte("constellation-master-secret-leng"), masterSecret: []byte("constellation-master-secret-leng"),
}, },

View File

@ -8,7 +8,7 @@ import (
) )
type recoveryClient interface { type recoveryClient interface {
Connect(ip, port string, validators []atls.Validator) error Connect(endpoint string, validators []atls.Validator) error
PushStateDiskKey(ctx context.Context, stateDiskKey []byte) error PushStateDiskKey(ctx context.Context, stateDiskKey []byte) error
io.Closer io.Closer
} }

View File

@ -15,7 +15,7 @@ type stubRecoveryClient struct {
pushStateDiskKeyKey []byte pushStateDiskKeyKey []byte
} }
func (c *stubRecoveryClient) Connect(_, _ string, _ []atls.Validator) error { func (c *stubRecoveryClient) Connect(_ string, _ []atls.Validator) error {
c.conn = true c.conn = true
return c.connectErr return c.connectErr
} }

View File

@ -3,6 +3,9 @@ package cmd
import ( import (
"errors" "errors"
"fmt" "fmt"
"net"
"strconv"
"strings"
"github.com/edgelesssys/constellation/cli/azure" "github.com/edgelesssys/constellation/cli/azure"
"github.com/edgelesssys/constellation/cli/cloudprovider" "github.com/edgelesssys/constellation/cli/cloudprovider"
@ -49,3 +52,16 @@ func validInstanceTypeForProvider(insType string, provider cloudprovider.Provide
return fmt.Errorf("%s isn't a valid cloud platform", provider) return fmt.Errorf("%s isn't a valid cloud platform", provider)
} }
} }
func validateEndpoint(endpoint string, defaultPort int) (string, error) {
_, _, err := net.SplitHostPort(endpoint)
if err == nil {
return endpoint, nil
}
if strings.Contains(err.Error(), "missing port in address") {
return net.JoinHostPort(endpoint, strconv.Itoa(defaultPort)), nil
}
return "", err
}

View File

@ -5,6 +5,7 @@ import (
"github.com/spf13/cobra" "github.com/spf13/cobra"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
) )
func TestIsCloudProvider(t *testing.T) { func TestIsCloudProvider(t *testing.T) {
@ -36,3 +37,54 @@ func TestIsCloudProvider(t *testing.T) {
}) })
} }
} }
func TestValidateEndpoint(t *testing.T) {
testCases := map[string]struct {
endpoint string
defaultPort int
wantResult string
wantErr bool
}{
"ip and port": {
endpoint: "192.0.2.1:2",
defaultPort: 3,
wantResult: "192.0.2.1:2",
},
"hostname and port": {
endpoint: "foo:2",
defaultPort: 3,
wantResult: "foo:2",
},
"ip": {
endpoint: "192.0.2.1",
defaultPort: 3,
wantResult: "192.0.2.1:3",
},
"hostname": {
endpoint: "foo",
defaultPort: 3,
wantResult: "foo:3",
},
"invalid endpoint": {
endpoint: "foo:2:2",
defaultPort: 3,
wantErr: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
res, err := validateEndpoint(tc.endpoint, tc.defaultPort)
if tc.wantErr {
assert.Error(err)
return
}
require.NoError(err)
assert.Equal(tc.wantResult, res)
})
}
}

View File

@ -4,9 +4,6 @@ import (
"context" "context"
"errors" "errors"
"fmt" "fmt"
"net"
"strconv"
"strings"
"github.com/edgelesssys/constellation/cli/cloud/cloudcmd" "github.com/edgelesssys/constellation/cli/cloud/cloudcmd"
"github.com/edgelesssys/constellation/cli/cloudprovider" "github.com/edgelesssys/constellation/cli/cloudprovider"
@ -69,7 +66,7 @@ func verify(ctx context.Context, cmd *cobra.Command, provider cloudprovider.Prov
cmd.Print(validators.Warnings()) cmd.Print(validators.Warnings())
} }
if err := protoClient.Connect(flags.nodeHost, flags.nodePort, validators.V()); err != nil { if err := protoClient.Connect(flags.endpoint, validators.V()); err != nil {
return err return err
} }
if _, err := protoClient.GetState(ctx); err != nil { if _, err := protoClient.GetState(ctx); err != nil {
@ -100,13 +97,9 @@ func parseVerifyFlags(cmd *cobra.Command) (verifyFlags, error) {
if err != nil { if err != nil {
return verifyFlags{}, err return verifyFlags{}, err
} }
host, port, err := net.SplitHostPort(endpoint) endpoint, err = validateEndpoint(endpoint, constants.CoordinatorPort)
if err != nil { if err != nil {
if !strings.Contains(err.Error(), "missing port in address") { return verifyFlags{}, err
return verifyFlags{}, err
}
host = endpoint
port = strconv.Itoa(constants.CoordinatorPort)
} }
devConfigPath, err := cmd.Flags().GetString("dev-config") devConfigPath, err := cmd.Flags().GetString("dev-config")
@ -115,8 +108,7 @@ func parseVerifyFlags(cmd *cobra.Command) (verifyFlags, error) {
} }
return verifyFlags{ return verifyFlags{
nodeHost: host, endpoint: endpoint,
nodePort: port,
devConfigPath: devConfigPath, devConfigPath: devConfigPath,
ownerID: ownerID, ownerID: ownerID,
clusterID: clusterID, clusterID: clusterID,
@ -124,8 +116,7 @@ func parseVerifyFlags(cmd *cobra.Command) (verifyFlags, error) {
} }
type verifyFlags struct { type verifyFlags struct {
nodeHost string endpoint string
nodePort string
ownerID string ownerID string
clusterID string clusterID string
devConfigPath string devConfigPath string

View File

@ -4,7 +4,6 @@ import (
"context" "context"
"errors" "errors"
"io" "io"
"net"
"github.com/edgelesssys/constellation/coordinator/atls" "github.com/edgelesssys/constellation/coordinator/atls"
"github.com/edgelesssys/constellation/coordinator/kms" "github.com/edgelesssys/constellation/coordinator/kms"
@ -26,15 +25,13 @@ type Client struct {
// The connection must be closed using Close(). If connect is // The connection must be closed using Close(). If connect is
// called on a client that already has a connection, the old // called on a client that already has a connection, the old
// connection is closed. // connection is closed.
func (c *Client) Connect(ip, port string, validators []atls.Validator) error { func (c *Client) Connect(endpoint string, validators []atls.Validator) error {
addr := net.JoinHostPort(ip, port)
tlsConfig, err := atls.CreateAttestationClientTLSConfig(validators) tlsConfig, err := atls.CreateAttestationClientTLSConfig(validators)
if err != nil { if err != nil {
return err return err
} }
conn, err := grpc.Dial(addr, grpc.WithTransportCredentials(credentials.NewTLS(tlsConfig))) conn, err := grpc.Dial(endpoint, grpc.WithTransportCredentials(credentials.NewTLS(tlsConfig)))
if err != nil { if err != nil {
return err return err
} }

View File

@ -3,7 +3,6 @@ package proto
import ( import (
"context" "context"
"errors" "errors"
"net"
"github.com/edgelesssys/constellation/coordinator/atls" "github.com/edgelesssys/constellation/coordinator/atls"
"github.com/edgelesssys/constellation/state/keyservice/keyproto" "github.com/edgelesssys/constellation/state/keyservice/keyproto"
@ -22,15 +21,13 @@ type KeyClient struct {
// The connection must be closed using Close(). If connect is // The connection must be closed using Close(). If connect is
// called on a client that already has a connection, the old // called on a client that already has a connection, the old
// connection is closed. // connection is closed.
func (c *KeyClient) Connect(ip, port string, validators []atls.Validator) error { func (c *KeyClient) Connect(endpoint string, validators []atls.Validator) error {
addr := net.JoinHostPort(ip, port)
tlsConfig, err := atls.CreateAttestationClientTLSConfig(validators) tlsConfig, err := atls.CreateAttestationClientTLSConfig(validators)
if err != nil { if err != nil {
return err return err
} }
conn, err := grpc.Dial(addr, grpc.WithTransportCredentials(credentials.NewTLS(tlsConfig))) conn, err := grpc.Dial(endpoint, grpc.WithTransportCredentials(credentials.NewTLS(tlsConfig)))
if err != nil { if err != nil {
return err return err
} }