diff --git a/cli/internal/cmd/details.go b/cli/internal/cmd/id.go similarity index 81% rename from cli/internal/cmd/details.go rename to cli/internal/cmd/id.go index 81e04ca51..c2c0aacbb 100644 --- a/cli/internal/cmd/details.go +++ b/cli/internal/cmd/id.go @@ -3,5 +3,5 @@ package cmd type clusterIDsFile struct { ClusterID string OwnerID string - Endpoint string + IP string } diff --git a/cli/internal/cmd/init.go b/cli/internal/cmd/init.go index 1014d4e60..485cd78fe 100644 --- a/cli/internal/cmd/init.go +++ b/cli/internal/cmd/init.go @@ -217,7 +217,7 @@ func writeOutput(resp *initproto.InitResponse, ip string, wr io.Writer, fileHand idFile := clusterIDsFile{ ClusterID: clusterID, OwnerID: ownerID, - Endpoint: net.JoinHostPort(ip, strconv.Itoa(constants.VerifyServiceNodePortGRPC)), + IP: ip, } if err := fileHandler.WriteJSON(constants.ClusterIDsFileName, idFile, file.OptNone); err != nil { return fmt.Errorf("writing Constellation id file: %w", err) diff --git a/cli/internal/cmd/init_test.go b/cli/internal/cmd/init_test.go index dcc615a6f..da32d6420 100644 --- a/cli/internal/cmd/init_test.go +++ b/cli/internal/cmd/init_test.go @@ -207,14 +207,14 @@ func TestWriteOutput(t *testing.T) { expectedIDFile := clusterIDsFile{ ClusterID: clusterID, OwnerID: ownerID, - Endpoint: net.JoinHostPort("ip", strconv.Itoa(constants.VerifyServiceNodePortGRPC)), + IP: "cluster-ip", } var out bytes.Buffer testFs := afero.NewMemMapFs() fileHandler := file.NewHandler(testFs) - err := writeOutput(resp, "ip", &out, fileHandler) + err := writeOutput(resp, "cluster-ip", &out, fileHandler) assert.NoError(err) // assert.Contains(out.String(), ownerID) assert.Contains(out.String(), clusterID) diff --git a/cli/internal/cmd/recover.go b/cli/internal/cmd/recover.go index 60aba2584..de9c04921 100644 --- a/cli/internal/cmd/recover.go +++ b/cli/internal/cmd/recover.go @@ -100,7 +100,7 @@ func parseRecoverFlags(cmd *cobra.Command) (recoverFlags, error) { if err != nil { return recoverFlags{}, fmt.Errorf("parsing endpoint argument: %w", err) } - endpoint, err = validateEndpoint(endpoint, constants.BootstrapperPort) + endpoint, err = addPortIfMissing(endpoint, constants.BootstrapperPort) if err != nil { return recoverFlags{}, fmt.Errorf("validating endpoint argument: %w", err) } diff --git a/cli/internal/cmd/validargs.go b/cli/internal/cmd/validargs.go index 521f9b6db..4f94d662d 100644 --- a/cli/internal/cmd/validargs.go +++ b/cli/internal/cmd/validargs.go @@ -3,9 +3,6 @@ package cmd import ( "errors" "fmt" - "net" - "strconv" - "strings" "github.com/edgelesssys/constellation/cli/internal/azure" "github.com/edgelesssys/constellation/cli/internal/gcp" @@ -56,20 +53,3 @@ func validInstanceTypeForProvider(cmd *cobra.Command, insType string, provider c return fmt.Errorf("%s isn't a valid cloud platform", provider) } } - -func validateEndpoint(endpoint string, defaultPort int) (string, error) { - if endpoint == "" { - return "", errors.New("endpoint is empty") - } - - _, _, err := net.SplitHostPort(endpoint) - if err == nil { - return endpoint, nil - } - - if strings.Contains(err.Error(), "missing port in address") { - return net.JoinHostPort(endpoint, strconv.Itoa(defaultPort)), nil - } - - return "", err -} diff --git a/cli/internal/cmd/validargs_test.go b/cli/internal/cmd/validargs_test.go index 5dd89aea2..508bcfa4c 100644 --- a/cli/internal/cmd/validargs_test.go +++ b/cli/internal/cmd/validargs_test.go @@ -5,7 +5,6 @@ import ( "github.com/spf13/cobra" "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" ) func TestIsCloudProvider(t *testing.T) { @@ -37,59 +36,3 @@ func TestIsCloudProvider(t *testing.T) { }) } } - -func TestValidateEndpoint(t *testing.T) { - testCases := map[string]struct { - endpoint string - defaultPort int - wantResult string - wantErr bool - }{ - "ip and port": { - endpoint: "192.0.2.1:2", - defaultPort: 3, - wantResult: "192.0.2.1:2", - }, - "hostname and port": { - endpoint: "foo:2", - defaultPort: 3, - wantResult: "foo:2", - }, - "ip": { - endpoint: "192.0.2.1", - defaultPort: 3, - wantResult: "192.0.2.1:3", - }, - "hostname": { - endpoint: "foo", - defaultPort: 3, - wantResult: "foo:3", - }, - "empty endpoint": { - endpoint: "", - defaultPort: 3, - wantErr: true, - }, - "invalid endpoint": { - endpoint: "foo:2:2", - defaultPort: 3, - wantErr: true, - }, - } - - for name, tc := range testCases { - t.Run(name, func(t *testing.T) { - assert := assert.New(t) - require := require.New(t) - - res, err := validateEndpoint(tc.endpoint, tc.defaultPort) - if tc.wantErr { - assert.Error(err) - return - } - - require.NoError(err) - assert.Equal(tc.wantResult, res) - }) - } -} diff --git a/cli/internal/cmd/verify.go b/cli/internal/cmd/verify.go index 3681383de..496541c1f 100644 --- a/cli/internal/cmd/verify.go +++ b/cli/internal/cmd/verify.go @@ -5,8 +5,9 @@ import ( "context" "errors" "fmt" - "io/fs" "net" + "strconv" + "strings" "github.com/edgelesssys/constellation/cli/internal/cloudcmd" "github.com/edgelesssys/constellation/internal/atls" @@ -114,30 +115,30 @@ func parseVerifyFlags(cmd *cobra.Command, fileHandler file.Handler) (verifyFlags return verifyFlags{}, fmt.Errorf("parsing node-endpoint argument: %w", err) } - // Get empty values from ID file emptyEndpoint := endpoint == "" emptyIDs := ownerID == "" && clusterID == "" - if emptyEndpoint || emptyIDs { - if details, err := readIds(fileHandler); err == nil { - if emptyEndpoint { - cmd.Printf("Using endpoint from %q. Specify --node-endpoint to override this.\n", constants.ClusterIDsFileName) - endpoint = details.Endpoint - } - if emptyIDs { - cmd.Printf("Using IDs from %q. Specify --owner-id and/or --cluster-id to override this.\n", constants.ClusterIDsFileName) - ownerID = details.OwnerID - clusterID = details.ClusterID - } - } else if !errors.Is(err, fs.ErrNotExist) { - return verifyFlags{}, err + + var idFile clusterIDsFile + if emptyEndpoint || emptyIDs { // Get empty values from ID file + if err := fileHandler.ReadJSON(constants.ClusterIDsFileName, &idFile); err != nil { + return verifyFlags{}, fmt.Errorf("reading cluster ID file: %w", err) } } + if emptyEndpoint { + cmd.Printf("Using endpoint from %q. Specify --node-endpoint to override this.\n", constants.ClusterIDsFileName) + endpoint = idFile.IP + } + if emptyIDs { + cmd.Printf("Using IDs from %q. Specify --owner-id and/or --cluster-id to override this.\n", constants.ClusterIDsFileName) + ownerID = idFile.OwnerID + clusterID = idFile.ClusterID + } // Validate if ownerID == "" && clusterID == "" { return verifyFlags{}, errors.New("neither owner-id nor cluster-id provided to verify the cluster") } - endpoint, err = validateEndpoint(endpoint, constants.VerifyServiceNodePortGRPC) + endpoint, err = addPortIfMissing(endpoint, constants.VerifyServiceNodePortGRPC) if err != nil { return verifyFlags{}, fmt.Errorf("validating endpoint argument: %w", err) } @@ -157,12 +158,21 @@ type verifyFlags struct { configPath string } -func readIds(fileHandler file.Handler) (clusterIDsFile, error) { - det := clusterIDsFile{} - if err := fileHandler.ReadJSON(constants.ClusterIDsFileName, &det); err != nil { - return clusterIDsFile{}, fmt.Errorf("reading cluster ids: %w", err) +func addPortIfMissing(endpoint string, defaultPort int) (string, error) { + if endpoint == "" { + return "", errors.New("endpoint is empty") } - return det, nil + + _, _, err := net.SplitHostPort(endpoint) + if err == nil { + return endpoint, nil + } + + if strings.Contains(err.Error(), "missing port in address") { + return net.JoinHostPort(endpoint, strconv.Itoa(defaultPort)), nil + } + + return "", err } // verifyCompletion handles the completion of CLI arguments. It is frequently called diff --git a/cli/internal/cmd/verify_test.go b/cli/internal/cmd/verify_test.go index 1d2ba119c..63e4ca0ab 100644 --- a/cli/internal/cmd/verify_test.go +++ b/cli/internal/cmd/verify_test.go @@ -106,8 +106,8 @@ func TestVerify(t *testing.T) { provider: cloudprovider.GCP, ownerIDFlag: zeroBase64, protoClient: &stubVerifyClient{}, - idFile: &clusterIDsFile{Endpoint: "192.0.2.1:1234"}, - wantEndpoint: "192.0.2.1:1234", + idFile: &clusterIDsFile{IP: "192.0.2.1"}, + wantEndpoint: "192.0.2.1:30081", }, "override endpoint from details file": { setupFs: func(require *require.Assertions) afero.Fs { return afero.NewMemMapFs() }, @@ -115,7 +115,7 @@ func TestVerify(t *testing.T) { nodeEndpointFlag: "192.0.2.2:1234", ownerIDFlag: zeroBase64, protoClient: &stubVerifyClient{}, - idFile: &clusterIDsFile{Endpoint: "192.0.2.1:1234"}, + idFile: &clusterIDsFile{IP: "192.0.2.1"}, wantEndpoint: "192.0.2.2:1234", }, "invalid endpoint": { @@ -345,3 +345,59 @@ type stubVerifyAPI struct { func (a stubVerifyAPI) GetAttestation(context.Context, *verifyproto.GetAttestationRequest) (*verifyproto.GetAttestationResponse, error) { return a.attestation, a.attestationErr } + +func TestAddPortIfMissing(t *testing.T) { + testCases := map[string]struct { + endpoint string + defaultPort int + wantResult string + wantErr bool + }{ + "ip and port": { + endpoint: "192.0.2.1:2", + defaultPort: 3, + wantResult: "192.0.2.1:2", + }, + "hostname and port": { + endpoint: "foo:2", + defaultPort: 3, + wantResult: "foo:2", + }, + "ip": { + endpoint: "192.0.2.1", + defaultPort: 3, + wantResult: "192.0.2.1:3", + }, + "hostname": { + endpoint: "foo", + defaultPort: 3, + wantResult: "foo:3", + }, + "empty endpoint": { + endpoint: "", + defaultPort: 3, + wantErr: true, + }, + "invalid endpoint": { + endpoint: "foo:2:2", + defaultPort: 3, + wantErr: true, + }, + } + + for name, tc := range testCases { + t.Run(name, func(t *testing.T) { + assert := assert.New(t) + require := require.New(t) + + res, err := addPortIfMissing(tc.endpoint, tc.defaultPort) + if tc.wantErr { + assert.Error(err) + return + } + + require.NoError(err) + assert.Equal(tc.wantResult, res) + }) + } +}