Refactor id file interaction

* Use IP instead of endpoint in clusterIDsFile
* Move and rename validateEnpoint to addPortIfMissing
* Refactor clusterIDsFile handling in verify cmd
This commit is contained in:
katexochen 2022-07-29 08:24:13 +02:00 committed by Paul Meyer
parent c2faa20d6e
commit 7bbcc564bb
8 changed files with 95 additions and 106 deletions

View File

@ -3,5 +3,5 @@ package cmd
type clusterIDsFile struct {
ClusterID string
OwnerID string
Endpoint string
IP string
}

View File

@ -217,7 +217,7 @@ func writeOutput(resp *initproto.InitResponse, ip string, wr io.Writer, fileHand
idFile := clusterIDsFile{
ClusterID: clusterID,
OwnerID: ownerID,
Endpoint: net.JoinHostPort(ip, strconv.Itoa(constants.VerifyServiceNodePortGRPC)),
IP: ip,
}
if err := fileHandler.WriteJSON(constants.ClusterIDsFileName, idFile, file.OptNone); err != nil {
return fmt.Errorf("writing Constellation id file: %w", err)

View File

@ -207,14 +207,14 @@ func TestWriteOutput(t *testing.T) {
expectedIDFile := clusterIDsFile{
ClusterID: clusterID,
OwnerID: ownerID,
Endpoint: net.JoinHostPort("ip", strconv.Itoa(constants.VerifyServiceNodePortGRPC)),
IP: "cluster-ip",
}
var out bytes.Buffer
testFs := afero.NewMemMapFs()
fileHandler := file.NewHandler(testFs)
err := writeOutput(resp, "ip", &out, fileHandler)
err := writeOutput(resp, "cluster-ip", &out, fileHandler)
assert.NoError(err)
// assert.Contains(out.String(), ownerID)
assert.Contains(out.String(), clusterID)

View File

@ -100,7 +100,7 @@ func parseRecoverFlags(cmd *cobra.Command) (recoverFlags, error) {
if err != nil {
return recoverFlags{}, fmt.Errorf("parsing endpoint argument: %w", err)
}
endpoint, err = validateEndpoint(endpoint, constants.BootstrapperPort)
endpoint, err = addPortIfMissing(endpoint, constants.BootstrapperPort)
if err != nil {
return recoverFlags{}, fmt.Errorf("validating endpoint argument: %w", err)
}

View File

@ -3,9 +3,6 @@ package cmd
import (
"errors"
"fmt"
"net"
"strconv"
"strings"
"github.com/edgelesssys/constellation/cli/internal/azure"
"github.com/edgelesssys/constellation/cli/internal/gcp"
@ -56,20 +53,3 @@ func validInstanceTypeForProvider(cmd *cobra.Command, insType string, provider c
return fmt.Errorf("%s isn't a valid cloud platform", provider)
}
}
func validateEndpoint(endpoint string, defaultPort int) (string, error) {
if endpoint == "" {
return "", errors.New("endpoint is empty")
}
_, _, err := net.SplitHostPort(endpoint)
if err == nil {
return endpoint, nil
}
if strings.Contains(err.Error(), "missing port in address") {
return net.JoinHostPort(endpoint, strconv.Itoa(defaultPort)), nil
}
return "", err
}

View File

@ -5,7 +5,6 @@ import (
"github.com/spf13/cobra"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestIsCloudProvider(t *testing.T) {
@ -37,59 +36,3 @@ func TestIsCloudProvider(t *testing.T) {
})
}
}
func TestValidateEndpoint(t *testing.T) {
testCases := map[string]struct {
endpoint string
defaultPort int
wantResult string
wantErr bool
}{
"ip and port": {
endpoint: "192.0.2.1:2",
defaultPort: 3,
wantResult: "192.0.2.1:2",
},
"hostname and port": {
endpoint: "foo:2",
defaultPort: 3,
wantResult: "foo:2",
},
"ip": {
endpoint: "192.0.2.1",
defaultPort: 3,
wantResult: "192.0.2.1:3",
},
"hostname": {
endpoint: "foo",
defaultPort: 3,
wantResult: "foo:3",
},
"empty endpoint": {
endpoint: "",
defaultPort: 3,
wantErr: true,
},
"invalid endpoint": {
endpoint: "foo:2:2",
defaultPort: 3,
wantErr: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
res, err := validateEndpoint(tc.endpoint, tc.defaultPort)
if tc.wantErr {
assert.Error(err)
return
}
require.NoError(err)
assert.Equal(tc.wantResult, res)
})
}
}

View File

@ -5,8 +5,9 @@ import (
"context"
"errors"
"fmt"
"io/fs"
"net"
"strconv"
"strings"
"github.com/edgelesssys/constellation/cli/internal/cloudcmd"
"github.com/edgelesssys/constellation/internal/atls"
@ -114,30 +115,30 @@ func parseVerifyFlags(cmd *cobra.Command, fileHandler file.Handler) (verifyFlags
return verifyFlags{}, fmt.Errorf("parsing node-endpoint argument: %w", err)
}
// Get empty values from ID file
emptyEndpoint := endpoint == ""
emptyIDs := ownerID == "" && clusterID == ""
if emptyEndpoint || emptyIDs {
if details, err := readIds(fileHandler); err == nil {
if emptyEndpoint {
cmd.Printf("Using endpoint from %q. Specify --node-endpoint to override this.\n", constants.ClusterIDsFileName)
endpoint = details.Endpoint
}
if emptyIDs {
cmd.Printf("Using IDs from %q. Specify --owner-id and/or --cluster-id to override this.\n", constants.ClusterIDsFileName)
ownerID = details.OwnerID
clusterID = details.ClusterID
}
} else if !errors.Is(err, fs.ErrNotExist) {
return verifyFlags{}, err
var idFile clusterIDsFile
if emptyEndpoint || emptyIDs { // Get empty values from ID file
if err := fileHandler.ReadJSON(constants.ClusterIDsFileName, &idFile); err != nil {
return verifyFlags{}, fmt.Errorf("reading cluster ID file: %w", err)
}
}
if emptyEndpoint {
cmd.Printf("Using endpoint from %q. Specify --node-endpoint to override this.\n", constants.ClusterIDsFileName)
endpoint = idFile.IP
}
if emptyIDs {
cmd.Printf("Using IDs from %q. Specify --owner-id and/or --cluster-id to override this.\n", constants.ClusterIDsFileName)
ownerID = idFile.OwnerID
clusterID = idFile.ClusterID
}
// Validate
if ownerID == "" && clusterID == "" {
return verifyFlags{}, errors.New("neither owner-id nor cluster-id provided to verify the cluster")
}
endpoint, err = validateEndpoint(endpoint, constants.VerifyServiceNodePortGRPC)
endpoint, err = addPortIfMissing(endpoint, constants.VerifyServiceNodePortGRPC)
if err != nil {
return verifyFlags{}, fmt.Errorf("validating endpoint argument: %w", err)
}
@ -157,12 +158,21 @@ type verifyFlags struct {
configPath string
}
func readIds(fileHandler file.Handler) (clusterIDsFile, error) {
det := clusterIDsFile{}
if err := fileHandler.ReadJSON(constants.ClusterIDsFileName, &det); err != nil {
return clusterIDsFile{}, fmt.Errorf("reading cluster ids: %w", err)
func addPortIfMissing(endpoint string, defaultPort int) (string, error) {
if endpoint == "" {
return "", errors.New("endpoint is empty")
}
return det, nil
_, _, err := net.SplitHostPort(endpoint)
if err == nil {
return endpoint, nil
}
if strings.Contains(err.Error(), "missing port in address") {
return net.JoinHostPort(endpoint, strconv.Itoa(defaultPort)), nil
}
return "", err
}
// verifyCompletion handles the completion of CLI arguments. It is frequently called

View File

@ -106,8 +106,8 @@ func TestVerify(t *testing.T) {
provider: cloudprovider.GCP,
ownerIDFlag: zeroBase64,
protoClient: &stubVerifyClient{},
idFile: &clusterIDsFile{Endpoint: "192.0.2.1:1234"},
wantEndpoint: "192.0.2.1:1234",
idFile: &clusterIDsFile{IP: "192.0.2.1"},
wantEndpoint: "192.0.2.1:30081",
},
"override endpoint from details file": {
setupFs: func(require *require.Assertions) afero.Fs { return afero.NewMemMapFs() },
@ -115,7 +115,7 @@ func TestVerify(t *testing.T) {
nodeEndpointFlag: "192.0.2.2:1234",
ownerIDFlag: zeroBase64,
protoClient: &stubVerifyClient{},
idFile: &clusterIDsFile{Endpoint: "192.0.2.1:1234"},
idFile: &clusterIDsFile{IP: "192.0.2.1"},
wantEndpoint: "192.0.2.2:1234",
},
"invalid endpoint": {
@ -345,3 +345,59 @@ type stubVerifyAPI struct {
func (a stubVerifyAPI) GetAttestation(context.Context, *verifyproto.GetAttestationRequest) (*verifyproto.GetAttestationResponse, error) {
return a.attestation, a.attestationErr
}
func TestAddPortIfMissing(t *testing.T) {
testCases := map[string]struct {
endpoint string
defaultPort int
wantResult string
wantErr bool
}{
"ip and port": {
endpoint: "192.0.2.1:2",
defaultPort: 3,
wantResult: "192.0.2.1:2",
},
"hostname and port": {
endpoint: "foo:2",
defaultPort: 3,
wantResult: "foo:2",
},
"ip": {
endpoint: "192.0.2.1",
defaultPort: 3,
wantResult: "192.0.2.1:3",
},
"hostname": {
endpoint: "foo",
defaultPort: 3,
wantResult: "foo:3",
},
"empty endpoint": {
endpoint: "",
defaultPort: 3,
wantErr: true,
},
"invalid endpoint": {
endpoint: "foo:2:2",
defaultPort: 3,
wantErr: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
res, err := addPortIfMissing(tc.endpoint, tc.defaultPort)
if tc.wantErr {
assert.Error(err)
return
}
require.NoError(err)
assert.Equal(tc.wantResult, res)
})
}
}