mirror of
https://github.com/edgelesssys/constellation.git
synced 2025-05-02 14:26:23 -04:00
cli: refactor flag parsing code (#2425)
Signed-off-by: Daniel Weiße <dw@edgeless.systems>
This commit is contained in:
parent
adfe443b28
commit
c52086c5ff
34 changed files with 1490 additions and 1726 deletions
|
@ -26,7 +26,6 @@ import (
|
|||
tpmProto "github.com/google/go-tpm-tools/proto/tpm"
|
||||
|
||||
"github.com/edgelesssys/constellation/v2/cli/internal/cloudcmd"
|
||||
"github.com/edgelesssys/constellation/v2/cli/internal/cmd/pathprefix"
|
||||
"github.com/edgelesssys/constellation/v2/cli/internal/state"
|
||||
"github.com/edgelesssys/constellation/v2/internal/api/attestationconfigapi"
|
||||
"github.com/edgelesssys/constellation/v2/internal/atls"
|
||||
|
@ -45,6 +44,7 @@ import (
|
|||
"github.com/google/go-sev-guest/kds"
|
||||
"github.com/spf13/afero"
|
||||
"github.com/spf13/cobra"
|
||||
"github.com/spf13/pflag"
|
||||
"google.golang.org/grpc"
|
||||
)
|
||||
|
||||
|
@ -64,8 +64,39 @@ func NewVerifyCmd() *cobra.Command {
|
|||
return cmd
|
||||
}
|
||||
|
||||
type verifyFlags struct {
|
||||
rootFlags
|
||||
endpoint string
|
||||
ownerID string
|
||||
clusterID string
|
||||
output string
|
||||
}
|
||||
|
||||
func (f *verifyFlags) parse(flags *pflag.FlagSet) error {
|
||||
if err := f.rootFlags.parse(flags); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var err error
|
||||
f.output, err = flags.GetString("output")
|
||||
if err != nil {
|
||||
return fmt.Errorf("getting 'output' flag: %w", err)
|
||||
}
|
||||
f.endpoint, err = flags.GetString("node-endpoint")
|
||||
if err != nil {
|
||||
return fmt.Errorf("getting 'node-endpoint' flag: %w", err)
|
||||
}
|
||||
f.clusterID, err = flags.GetString("cluster-id")
|
||||
if err != nil {
|
||||
return fmt.Errorf("getting 'cluster-id' flag: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type verifyCmd struct {
|
||||
log debugLog
|
||||
fileHandler file.Handler
|
||||
flags verifyFlags
|
||||
log debugLog
|
||||
}
|
||||
|
||||
func runVerify(cmd *cobra.Command, _ []string) error {
|
||||
|
@ -95,22 +126,23 @@ func runVerify(cmd *cobra.Command, _ []string) error {
|
|||
return nil, fmt.Errorf("invalid output value for formatter: %s", output)
|
||||
}
|
||||
}
|
||||
v := &verifyCmd{log: log}
|
||||
v := &verifyCmd{
|
||||
fileHandler: fileHandler,
|
||||
log: log,
|
||||
}
|
||||
if err := v.flags.parse(cmd.Flags()); err != nil {
|
||||
return err
|
||||
}
|
||||
v.log.Debugf("Using flags: %+v", v.flags)
|
||||
fetcher := attestationconfigapi.NewFetcher()
|
||||
return v.verify(cmd, fileHandler, verifyClient, formatterFactory, fetcher)
|
||||
return v.verify(cmd, verifyClient, formatterFactory, fetcher)
|
||||
}
|
||||
|
||||
type formatterFactory func(output string, provider cloudprovider.Provider, log debugLog) (attestationDocFormatter, error)
|
||||
|
||||
func (c *verifyCmd) verify(cmd *cobra.Command, fileHandler file.Handler, verifyClient verifyClient, factory formatterFactory, configFetcher attestationconfigapi.Fetcher) error {
|
||||
flags, err := c.parseVerifyFlags(cmd, fileHandler)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parsing flags: %w", err)
|
||||
}
|
||||
c.log.Debugf("Using flags: %+v", flags)
|
||||
|
||||
c.log.Debugf("Loading configuration file from %q", flags.pf.PrefixPrintablePath(constants.ConfigFilename))
|
||||
conf, err := config.New(fileHandler, constants.ConfigFilename, configFetcher, flags.force)
|
||||
func (c *verifyCmd) verify(cmd *cobra.Command, verifyClient verifyClient, factory formatterFactory, configFetcher attestationconfigapi.Fetcher) error {
|
||||
c.log.Debugf("Loading configuration file from %q", c.flags.pathPrefixer.PrefixPrintablePath(constants.ConfigFilename))
|
||||
conf, err := config.New(c.fileHandler, constants.ConfigFilename, configFetcher, c.flags.force)
|
||||
var configValidationErr *config.ValidationError
|
||||
if errors.As(err, &configValidationErr) {
|
||||
cmd.PrintErrln(configValidationErr.LongMessage())
|
||||
|
@ -119,10 +151,29 @@ func (c *verifyCmd) verify(cmd *cobra.Command, fileHandler file.Handler, verifyC
|
|||
return fmt.Errorf("loading config file: %w", err)
|
||||
}
|
||||
|
||||
conf.UpdateMAAURL(flags.maaURL)
|
||||
stateFile, err := state.ReadFromFile(c.fileHandler, constants.StateFilename)
|
||||
if err != nil {
|
||||
return fmt.Errorf("reading state file: %w", err)
|
||||
}
|
||||
|
||||
ownerID, clusterID, err := c.validateIDFlags(cmd, stateFile)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
endpoint, err := c.validateEndpointFlag(cmd, stateFile)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var maaURL string
|
||||
if stateFile.Infrastructure.Azure != nil {
|
||||
maaURL = stateFile.Infrastructure.Azure.AttestationURL
|
||||
}
|
||||
conf.UpdateMAAURL(maaURL)
|
||||
|
||||
c.log.Debugf("Updating expected PCRs")
|
||||
attConfig := conf.GetAttestationConfig()
|
||||
if err := cloudcmd.UpdateInitMeasurements(attConfig, flags.ownerID, flags.clusterID); err != nil {
|
||||
if err := cloudcmd.UpdateInitMeasurements(attConfig, ownerID, clusterID); err != nil {
|
||||
return fmt.Errorf("updating expected PCRs: %w", err)
|
||||
}
|
||||
|
||||
|
@ -140,7 +191,7 @@ func (c *verifyCmd) verify(cmd *cobra.Command, fileHandler file.Handler, verifyC
|
|||
|
||||
rawAttestationDoc, err := verifyClient.Verify(
|
||||
cmd.Context(),
|
||||
flags.endpoint,
|
||||
endpoint,
|
||||
&verifyproto.GetAttestationRequest{
|
||||
Nonce: nonce,
|
||||
},
|
||||
|
@ -151,7 +202,7 @@ func (c *verifyCmd) verify(cmd *cobra.Command, fileHandler file.Handler, verifyC
|
|||
}
|
||||
|
||||
// certificates are only available for Azure
|
||||
formatter, err := factory(flags.output, conf.GetProvider(), c.log)
|
||||
formatter, err := factory(c.flags.output, conf.GetProvider(), c.log)
|
||||
if err != nil {
|
||||
return fmt.Errorf("creating formatter: %w", err)
|
||||
}
|
||||
|
@ -160,7 +211,7 @@ func (c *verifyCmd) verify(cmd *cobra.Command, fileHandler file.Handler, verifyC
|
|||
rawAttestationDoc,
|
||||
conf.Provider.Azure == nil,
|
||||
attConfig.GetMeasurements(),
|
||||
flags.maaURL,
|
||||
maaURL,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("printing attestation document: %w", err)
|
||||
|
@ -171,114 +222,37 @@ func (c *verifyCmd) verify(cmd *cobra.Command, fileHandler file.Handler, verifyC
|
|||
return nil
|
||||
}
|
||||
|
||||
func (c *verifyCmd) parseVerifyFlags(cmd *cobra.Command, fileHandler file.Handler) (verifyFlags, error) {
|
||||
workDir, err := cmd.Flags().GetString("workspace")
|
||||
if err != nil {
|
||||
return verifyFlags{}, fmt.Errorf("parsing config path argument: %w", err)
|
||||
func (c *verifyCmd) validateIDFlags(cmd *cobra.Command, stateFile *state.State) (ownerID, clusterID string, err error) {
|
||||
ownerID, clusterID = c.flags.ownerID, c.flags.clusterID
|
||||
if c.flags.clusterID == "" {
|
||||
cmd.PrintErrf("Using ID from %q. Specify --cluster-id to override this.\n", c.flags.pathPrefixer.PrefixPrintablePath(constants.StateFilename))
|
||||
clusterID = stateFile.ClusterValues.ClusterID
|
||||
}
|
||||
c.log.Debugf("Flag 'workspace' set to %q", workDir)
|
||||
pf := pathprefix.New(workDir)
|
||||
|
||||
ownerID := ""
|
||||
clusterID, err := cmd.Flags().GetString("cluster-id")
|
||||
if err != nil {
|
||||
return verifyFlags{}, fmt.Errorf("parsing cluster-id argument: %w", err)
|
||||
}
|
||||
c.log.Debugf("Flag 'cluster-id' set to %q", clusterID)
|
||||
|
||||
endpoint, err := cmd.Flags().GetString("node-endpoint")
|
||||
if err != nil {
|
||||
return verifyFlags{}, fmt.Errorf("parsing node-endpoint argument: %w", err)
|
||||
}
|
||||
c.log.Debugf("Flag 'node-endpoint' set to %q", endpoint)
|
||||
|
||||
force, err := cmd.Flags().GetBool("force")
|
||||
if err != nil {
|
||||
return verifyFlags{}, fmt.Errorf("parsing force argument: %w", err)
|
||||
}
|
||||
c.log.Debugf("Flag 'force' set to %t", force)
|
||||
|
||||
output, err := cmd.Flags().GetString("output")
|
||||
if err != nil {
|
||||
return verifyFlags{}, fmt.Errorf("parsing raw argument: %w", err)
|
||||
}
|
||||
c.log.Debugf("Flag 'output' set to %t", output)
|
||||
|
||||
// Get empty values from state file
|
||||
stateFile, err := state.ReadFromFile(fileHandler, constants.StateFilename)
|
||||
isFileNotFound := errors.Is(err, afero.ErrFileNotFound)
|
||||
if isFileNotFound {
|
||||
c.log.Debugf("State file %q not found, using empty state", pf.PrefixPrintablePath(constants.StateFilename))
|
||||
stateFile = state.New() // error compat
|
||||
} else if err != nil {
|
||||
return verifyFlags{}, fmt.Errorf("reading state file: %w", err)
|
||||
}
|
||||
|
||||
emptyEndpoint := endpoint == ""
|
||||
emptyIDs := ownerID == "" && clusterID == ""
|
||||
if emptyEndpoint || emptyIDs {
|
||||
c.log.Debugf("Trying to supplement empty flag values from %q", pf.PrefixPrintablePath(constants.StateFilename))
|
||||
if emptyEndpoint {
|
||||
cmd.PrintErrf("Using endpoint from %q. Specify --node-endpoint to override this.\n", pf.PrefixPrintablePath(constants.StateFilename))
|
||||
endpoint = stateFile.Infrastructure.ClusterEndpoint
|
||||
}
|
||||
if emptyIDs {
|
||||
cmd.PrintErrf("Using ID from %q. Specify --cluster-id to override this.\n", pf.PrefixPrintablePath(constants.StateFilename))
|
||||
ownerID = stateFile.ClusterValues.OwnerID
|
||||
clusterID = stateFile.ClusterValues.ClusterID
|
||||
}
|
||||
}
|
||||
|
||||
var attestationURL string
|
||||
if stateFile.Infrastructure.Azure != nil {
|
||||
attestationURL = stateFile.Infrastructure.Azure.AttestationURL
|
||||
if ownerID == "" {
|
||||
// We don't want to print warnings until this is implemented again
|
||||
// cmd.PrintErrf("Using ID from %q. Specify --owner-id to override this.\n", c.flags.pathPrefixer.PrefixPrintablePath(constants.StateFilename))
|
||||
ownerID = stateFile.ClusterValues.OwnerID
|
||||
}
|
||||
|
||||
// Validate
|
||||
if ownerID == "" && clusterID == "" {
|
||||
return verifyFlags{}, errors.New("cluster-id not provided to verify the cluster")
|
||||
}
|
||||
endpoint, err = addPortIfMissing(endpoint, constants.VerifyServiceNodePortGRPC)
|
||||
if err != nil {
|
||||
return verifyFlags{}, fmt.Errorf("validating endpoint argument: %w", err)
|
||||
return "", "", errors.New("cluster-id not provided to verify the cluster")
|
||||
}
|
||||
|
||||
return verifyFlags{
|
||||
endpoint: endpoint,
|
||||
pf: pf,
|
||||
ownerID: ownerID,
|
||||
clusterID: clusterID,
|
||||
output: output,
|
||||
maaURL: attestationURL,
|
||||
force: force,
|
||||
}, nil
|
||||
return ownerID, clusterID, nil
|
||||
}
|
||||
|
||||
type verifyFlags struct {
|
||||
endpoint string
|
||||
ownerID string
|
||||
clusterID string
|
||||
maaURL string
|
||||
output string
|
||||
force bool
|
||||
pf pathprefix.PathPrefixer
|
||||
}
|
||||
|
||||
func addPortIfMissing(endpoint string, defaultPort int) (string, error) {
|
||||
func (c *verifyCmd) validateEndpointFlag(cmd *cobra.Command, stateFile *state.State) (string, error) {
|
||||
endpoint := c.flags.endpoint
|
||||
if endpoint == "" {
|
||||
return "", errors.New("endpoint is empty")
|
||||
cmd.PrintErrf("Using endpoint from %q. Specify --node-endpoint to override this.\n", c.flags.pathPrefixer.PrefixPrintablePath(constants.StateFilename))
|
||||
endpoint = stateFile.Infrastructure.ClusterEndpoint
|
||||
}
|
||||
|
||||
_, _, err := net.SplitHostPort(endpoint)
|
||||
if err == nil {
|
||||
return endpoint, nil
|
||||
endpoint, err := addPortIfMissing(endpoint, constants.VerifyServiceNodePortGRPC)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("validating endpoint argument: %w", err)
|
||||
}
|
||||
|
||||
if strings.Contains(err.Error(), "missing port in address") {
|
||||
return net.JoinHostPort(endpoint, strconv.Itoa(defaultPort)), nil
|
||||
}
|
||||
|
||||
return "", err
|
||||
return endpoint, nil
|
||||
}
|
||||
|
||||
// an attestationDocFormatter formats the attestation document.
|
||||
|
@ -869,3 +843,20 @@ func extractAzureInstanceInfo(docString string) (azureInstanceInfo, error) {
|
|||
}
|
||||
return instanceInfo, nil
|
||||
}
|
||||
|
||||
func addPortIfMissing(endpoint string, defaultPort int) (string, error) {
|
||||
if endpoint == "" {
|
||||
return "", errors.New("endpoint is empty")
|
||||
}
|
||||
|
||||
_, _, err := net.SplitHostPort(endpoint)
|
||||
if err == nil {
|
||||
return endpoint, nil
|
||||
}
|
||||
|
||||
if strings.Contains(err.Error(), "missing port in address") {
|
||||
return net.JoinHostPort(endpoint, strconv.Itoa(defaultPort)), nil
|
||||
}
|
||||
|
||||
return "", err
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue