mirror of
https://github.com/edgelesssys/constellation.git
synced 2025-01-11 23:49:30 -05:00
Refactor id file interaction
* Use IP instead of endpoint in clusterIDsFile * Move and rename validateEnpoint to addPortIfMissing * Refactor clusterIDsFile handling in verify cmd
This commit is contained in:
parent
c2faa20d6e
commit
7bbcc564bb
@ -3,5 +3,5 @@ package cmd
|
||||
type clusterIDsFile struct {
|
||||
ClusterID string
|
||||
OwnerID string
|
||||
Endpoint string
|
||||
IP string
|
||||
}
|
@ -217,7 +217,7 @@ func writeOutput(resp *initproto.InitResponse, ip string, wr io.Writer, fileHand
|
||||
idFile := clusterIDsFile{
|
||||
ClusterID: clusterID,
|
||||
OwnerID: ownerID,
|
||||
Endpoint: net.JoinHostPort(ip, strconv.Itoa(constants.VerifyServiceNodePortGRPC)),
|
||||
IP: ip,
|
||||
}
|
||||
if err := fileHandler.WriteJSON(constants.ClusterIDsFileName, idFile, file.OptNone); err != nil {
|
||||
return fmt.Errorf("writing Constellation id file: %w", err)
|
||||
|
@ -207,14 +207,14 @@ func TestWriteOutput(t *testing.T) {
|
||||
expectedIDFile := clusterIDsFile{
|
||||
ClusterID: clusterID,
|
||||
OwnerID: ownerID,
|
||||
Endpoint: net.JoinHostPort("ip", strconv.Itoa(constants.VerifyServiceNodePortGRPC)),
|
||||
IP: "cluster-ip",
|
||||
}
|
||||
|
||||
var out bytes.Buffer
|
||||
testFs := afero.NewMemMapFs()
|
||||
fileHandler := file.NewHandler(testFs)
|
||||
|
||||
err := writeOutput(resp, "ip", &out, fileHandler)
|
||||
err := writeOutput(resp, "cluster-ip", &out, fileHandler)
|
||||
assert.NoError(err)
|
||||
// assert.Contains(out.String(), ownerID)
|
||||
assert.Contains(out.String(), clusterID)
|
||||
|
@ -100,7 +100,7 @@ func parseRecoverFlags(cmd *cobra.Command) (recoverFlags, error) {
|
||||
if err != nil {
|
||||
return recoverFlags{}, fmt.Errorf("parsing endpoint argument: %w", err)
|
||||
}
|
||||
endpoint, err = validateEndpoint(endpoint, constants.BootstrapperPort)
|
||||
endpoint, err = addPortIfMissing(endpoint, constants.BootstrapperPort)
|
||||
if err != nil {
|
||||
return recoverFlags{}, fmt.Errorf("validating endpoint argument: %w", err)
|
||||
}
|
||||
|
@ -3,9 +3,6 @@ package cmd
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/edgelesssys/constellation/cli/internal/azure"
|
||||
"github.com/edgelesssys/constellation/cli/internal/gcp"
|
||||
@ -56,20 +53,3 @@ func validInstanceTypeForProvider(cmd *cobra.Command, insType string, provider c
|
||||
return fmt.Errorf("%s isn't a valid cloud platform", provider)
|
||||
}
|
||||
}
|
||||
|
||||
func validateEndpoint(endpoint string, defaultPort int) (string, error) {
|
||||
if endpoint == "" {
|
||||
return "", errors.New("endpoint is empty")
|
||||
}
|
||||
|
||||
_, _, err := net.SplitHostPort(endpoint)
|
||||
if err == nil {
|
||||
return endpoint, nil
|
||||
}
|
||||
|
||||
if strings.Contains(err.Error(), "missing port in address") {
|
||||
return net.JoinHostPort(endpoint, strconv.Itoa(defaultPort)), nil
|
||||
}
|
||||
|
||||
return "", err
|
||||
}
|
||||
|
@ -5,7 +5,6 @@ import (
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestIsCloudProvider(t *testing.T) {
|
||||
@ -37,59 +36,3 @@ func TestIsCloudProvider(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateEndpoint(t *testing.T) {
|
||||
testCases := map[string]struct {
|
||||
endpoint string
|
||||
defaultPort int
|
||||
wantResult string
|
||||
wantErr bool
|
||||
}{
|
||||
"ip and port": {
|
||||
endpoint: "192.0.2.1:2",
|
||||
defaultPort: 3,
|
||||
wantResult: "192.0.2.1:2",
|
||||
},
|
||||
"hostname and port": {
|
||||
endpoint: "foo:2",
|
||||
defaultPort: 3,
|
||||
wantResult: "foo:2",
|
||||
},
|
||||
"ip": {
|
||||
endpoint: "192.0.2.1",
|
||||
defaultPort: 3,
|
||||
wantResult: "192.0.2.1:3",
|
||||
},
|
||||
"hostname": {
|
||||
endpoint: "foo",
|
||||
defaultPort: 3,
|
||||
wantResult: "foo:3",
|
||||
},
|
||||
"empty endpoint": {
|
||||
endpoint: "",
|
||||
defaultPort: 3,
|
||||
wantErr: true,
|
||||
},
|
||||
"invalid endpoint": {
|
||||
endpoint: "foo:2:2",
|
||||
defaultPort: 3,
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for name, tc := range testCases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
require := require.New(t)
|
||||
|
||||
res, err := validateEndpoint(tc.endpoint, tc.defaultPort)
|
||||
if tc.wantErr {
|
||||
assert.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
require.NoError(err)
|
||||
assert.Equal(tc.wantResult, res)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
@ -5,8 +5,9 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/fs"
|
||||
"net"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/edgelesssys/constellation/cli/internal/cloudcmd"
|
||||
"github.com/edgelesssys/constellation/internal/atls"
|
||||
@ -114,30 +115,30 @@ func parseVerifyFlags(cmd *cobra.Command, fileHandler file.Handler) (verifyFlags
|
||||
return verifyFlags{}, fmt.Errorf("parsing node-endpoint argument: %w", err)
|
||||
}
|
||||
|
||||
// Get empty values from ID file
|
||||
emptyEndpoint := endpoint == ""
|
||||
emptyIDs := ownerID == "" && clusterID == ""
|
||||
if emptyEndpoint || emptyIDs {
|
||||
if details, err := readIds(fileHandler); err == nil {
|
||||
if emptyEndpoint {
|
||||
cmd.Printf("Using endpoint from %q. Specify --node-endpoint to override this.\n", constants.ClusterIDsFileName)
|
||||
endpoint = details.Endpoint
|
||||
}
|
||||
if emptyIDs {
|
||||
cmd.Printf("Using IDs from %q. Specify --owner-id and/or --cluster-id to override this.\n", constants.ClusterIDsFileName)
|
||||
ownerID = details.OwnerID
|
||||
clusterID = details.ClusterID
|
||||
}
|
||||
} else if !errors.Is(err, fs.ErrNotExist) {
|
||||
return verifyFlags{}, err
|
||||
|
||||
var idFile clusterIDsFile
|
||||
if emptyEndpoint || emptyIDs { // Get empty values from ID file
|
||||
if err := fileHandler.ReadJSON(constants.ClusterIDsFileName, &idFile); err != nil {
|
||||
return verifyFlags{}, fmt.Errorf("reading cluster ID file: %w", err)
|
||||
}
|
||||
}
|
||||
if emptyEndpoint {
|
||||
cmd.Printf("Using endpoint from %q. Specify --node-endpoint to override this.\n", constants.ClusterIDsFileName)
|
||||
endpoint = idFile.IP
|
||||
}
|
||||
if emptyIDs {
|
||||
cmd.Printf("Using IDs from %q. Specify --owner-id and/or --cluster-id to override this.\n", constants.ClusterIDsFileName)
|
||||
ownerID = idFile.OwnerID
|
||||
clusterID = idFile.ClusterID
|
||||
}
|
||||
|
||||
// Validate
|
||||
if ownerID == "" && clusterID == "" {
|
||||
return verifyFlags{}, errors.New("neither owner-id nor cluster-id provided to verify the cluster")
|
||||
}
|
||||
endpoint, err = validateEndpoint(endpoint, constants.VerifyServiceNodePortGRPC)
|
||||
endpoint, err = addPortIfMissing(endpoint, constants.VerifyServiceNodePortGRPC)
|
||||
if err != nil {
|
||||
return verifyFlags{}, fmt.Errorf("validating endpoint argument: %w", err)
|
||||
}
|
||||
@ -157,12 +158,21 @@ type verifyFlags struct {
|
||||
configPath string
|
||||
}
|
||||
|
||||
func readIds(fileHandler file.Handler) (clusterIDsFile, error) {
|
||||
det := clusterIDsFile{}
|
||||
if err := fileHandler.ReadJSON(constants.ClusterIDsFileName, &det); err != nil {
|
||||
return clusterIDsFile{}, fmt.Errorf("reading cluster ids: %w", err)
|
||||
func addPortIfMissing(endpoint string, defaultPort int) (string, error) {
|
||||
if endpoint == "" {
|
||||
return "", errors.New("endpoint is empty")
|
||||
}
|
||||
return det, nil
|
||||
|
||||
_, _, err := net.SplitHostPort(endpoint)
|
||||
if err == nil {
|
||||
return endpoint, nil
|
||||
}
|
||||
|
||||
if strings.Contains(err.Error(), "missing port in address") {
|
||||
return net.JoinHostPort(endpoint, strconv.Itoa(defaultPort)), nil
|
||||
}
|
||||
|
||||
return "", err
|
||||
}
|
||||
|
||||
// verifyCompletion handles the completion of CLI arguments. It is frequently called
|
||||
|
@ -106,8 +106,8 @@ func TestVerify(t *testing.T) {
|
||||
provider: cloudprovider.GCP,
|
||||
ownerIDFlag: zeroBase64,
|
||||
protoClient: &stubVerifyClient{},
|
||||
idFile: &clusterIDsFile{Endpoint: "192.0.2.1:1234"},
|
||||
wantEndpoint: "192.0.2.1:1234",
|
||||
idFile: &clusterIDsFile{IP: "192.0.2.1"},
|
||||
wantEndpoint: "192.0.2.1:30081",
|
||||
},
|
||||
"override endpoint from details file": {
|
||||
setupFs: func(require *require.Assertions) afero.Fs { return afero.NewMemMapFs() },
|
||||
@ -115,7 +115,7 @@ func TestVerify(t *testing.T) {
|
||||
nodeEndpointFlag: "192.0.2.2:1234",
|
||||
ownerIDFlag: zeroBase64,
|
||||
protoClient: &stubVerifyClient{},
|
||||
idFile: &clusterIDsFile{Endpoint: "192.0.2.1:1234"},
|
||||
idFile: &clusterIDsFile{IP: "192.0.2.1"},
|
||||
wantEndpoint: "192.0.2.2:1234",
|
||||
},
|
||||
"invalid endpoint": {
|
||||
@ -345,3 +345,59 @@ type stubVerifyAPI struct {
|
||||
func (a stubVerifyAPI) GetAttestation(context.Context, *verifyproto.GetAttestationRequest) (*verifyproto.GetAttestationResponse, error) {
|
||||
return a.attestation, a.attestationErr
|
||||
}
|
||||
|
||||
func TestAddPortIfMissing(t *testing.T) {
|
||||
testCases := map[string]struct {
|
||||
endpoint string
|
||||
defaultPort int
|
||||
wantResult string
|
||||
wantErr bool
|
||||
}{
|
||||
"ip and port": {
|
||||
endpoint: "192.0.2.1:2",
|
||||
defaultPort: 3,
|
||||
wantResult: "192.0.2.1:2",
|
||||
},
|
||||
"hostname and port": {
|
||||
endpoint: "foo:2",
|
||||
defaultPort: 3,
|
||||
wantResult: "foo:2",
|
||||
},
|
||||
"ip": {
|
||||
endpoint: "192.0.2.1",
|
||||
defaultPort: 3,
|
||||
wantResult: "192.0.2.1:3",
|
||||
},
|
||||
"hostname": {
|
||||
endpoint: "foo",
|
||||
defaultPort: 3,
|
||||
wantResult: "foo:3",
|
||||
},
|
||||
"empty endpoint": {
|
||||
endpoint: "",
|
||||
defaultPort: 3,
|
||||
wantErr: true,
|
||||
},
|
||||
"invalid endpoint": {
|
||||
endpoint: "foo:2:2",
|
||||
defaultPort: 3,
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for name, tc := range testCases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
require := require.New(t)
|
||||
|
||||
res, err := addPortIfMissing(tc.endpoint, tc.defaultPort)
|
||||
if tc.wantErr {
|
||||
assert.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
require.NoError(err)
|
||||
assert.Equal(tc.wantResult, res)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user