cli: refactor flag parsing code (#2425)

Signed-off-by: Daniel Weiße <dw@edgeless.systems>
This commit is contained in:
Daniel Weiße 2023-10-16 15:05:29 +02:00 committed by GitHub
parent adfe443b28
commit c52086c5ff
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
34 changed files with 1490 additions and 1726 deletions

View file

@ -26,7 +26,6 @@ import (
tpmProto "github.com/google/go-tpm-tools/proto/tpm"
"github.com/edgelesssys/constellation/v2/cli/internal/cloudcmd"
"github.com/edgelesssys/constellation/v2/cli/internal/cmd/pathprefix"
"github.com/edgelesssys/constellation/v2/cli/internal/state"
"github.com/edgelesssys/constellation/v2/internal/api/attestationconfigapi"
"github.com/edgelesssys/constellation/v2/internal/atls"
@ -45,6 +44,7 @@ import (
"github.com/google/go-sev-guest/kds"
"github.com/spf13/afero"
"github.com/spf13/cobra"
"github.com/spf13/pflag"
"google.golang.org/grpc"
)
@ -64,8 +64,39 @@ func NewVerifyCmd() *cobra.Command {
return cmd
}
type verifyFlags struct {
rootFlags
endpoint string
ownerID string
clusterID string
output string
}
func (f *verifyFlags) parse(flags *pflag.FlagSet) error {
if err := f.rootFlags.parse(flags); err != nil {
return err
}
var err error
f.output, err = flags.GetString("output")
if err != nil {
return fmt.Errorf("getting 'output' flag: %w", err)
}
f.endpoint, err = flags.GetString("node-endpoint")
if err != nil {
return fmt.Errorf("getting 'node-endpoint' flag: %w", err)
}
f.clusterID, err = flags.GetString("cluster-id")
if err != nil {
return fmt.Errorf("getting 'cluster-id' flag: %w", err)
}
return nil
}
type verifyCmd struct {
log debugLog
fileHandler file.Handler
flags verifyFlags
log debugLog
}
func runVerify(cmd *cobra.Command, _ []string) error {
@ -95,22 +126,23 @@ func runVerify(cmd *cobra.Command, _ []string) error {
return nil, fmt.Errorf("invalid output value for formatter: %s", output)
}
}
v := &verifyCmd{log: log}
v := &verifyCmd{
fileHandler: fileHandler,
log: log,
}
if err := v.flags.parse(cmd.Flags()); err != nil {
return err
}
v.log.Debugf("Using flags: %+v", v.flags)
fetcher := attestationconfigapi.NewFetcher()
return v.verify(cmd, fileHandler, verifyClient, formatterFactory, fetcher)
return v.verify(cmd, verifyClient, formatterFactory, fetcher)
}
type formatterFactory func(output string, provider cloudprovider.Provider, log debugLog) (attestationDocFormatter, error)
func (c *verifyCmd) verify(cmd *cobra.Command, fileHandler file.Handler, verifyClient verifyClient, factory formatterFactory, configFetcher attestationconfigapi.Fetcher) error {
flags, err := c.parseVerifyFlags(cmd, fileHandler)
if err != nil {
return fmt.Errorf("parsing flags: %w", err)
}
c.log.Debugf("Using flags: %+v", flags)
c.log.Debugf("Loading configuration file from %q", flags.pf.PrefixPrintablePath(constants.ConfigFilename))
conf, err := config.New(fileHandler, constants.ConfigFilename, configFetcher, flags.force)
func (c *verifyCmd) verify(cmd *cobra.Command, verifyClient verifyClient, factory formatterFactory, configFetcher attestationconfigapi.Fetcher) error {
c.log.Debugf("Loading configuration file from %q", c.flags.pathPrefixer.PrefixPrintablePath(constants.ConfigFilename))
conf, err := config.New(c.fileHandler, constants.ConfigFilename, configFetcher, c.flags.force)
var configValidationErr *config.ValidationError
if errors.As(err, &configValidationErr) {
cmd.PrintErrln(configValidationErr.LongMessage())
@ -119,10 +151,29 @@ func (c *verifyCmd) verify(cmd *cobra.Command, fileHandler file.Handler, verifyC
return fmt.Errorf("loading config file: %w", err)
}
conf.UpdateMAAURL(flags.maaURL)
stateFile, err := state.ReadFromFile(c.fileHandler, constants.StateFilename)
if err != nil {
return fmt.Errorf("reading state file: %w", err)
}
ownerID, clusterID, err := c.validateIDFlags(cmd, stateFile)
if err != nil {
return err
}
endpoint, err := c.validateEndpointFlag(cmd, stateFile)
if err != nil {
return err
}
var maaURL string
if stateFile.Infrastructure.Azure != nil {
maaURL = stateFile.Infrastructure.Azure.AttestationURL
}
conf.UpdateMAAURL(maaURL)
c.log.Debugf("Updating expected PCRs")
attConfig := conf.GetAttestationConfig()
if err := cloudcmd.UpdateInitMeasurements(attConfig, flags.ownerID, flags.clusterID); err != nil {
if err := cloudcmd.UpdateInitMeasurements(attConfig, ownerID, clusterID); err != nil {
return fmt.Errorf("updating expected PCRs: %w", err)
}
@ -140,7 +191,7 @@ func (c *verifyCmd) verify(cmd *cobra.Command, fileHandler file.Handler, verifyC
rawAttestationDoc, err := verifyClient.Verify(
cmd.Context(),
flags.endpoint,
endpoint,
&verifyproto.GetAttestationRequest{
Nonce: nonce,
},
@ -151,7 +202,7 @@ func (c *verifyCmd) verify(cmd *cobra.Command, fileHandler file.Handler, verifyC
}
// certificates are only available for Azure
formatter, err := factory(flags.output, conf.GetProvider(), c.log)
formatter, err := factory(c.flags.output, conf.GetProvider(), c.log)
if err != nil {
return fmt.Errorf("creating formatter: %w", err)
}
@ -160,7 +211,7 @@ func (c *verifyCmd) verify(cmd *cobra.Command, fileHandler file.Handler, verifyC
rawAttestationDoc,
conf.Provider.Azure == nil,
attConfig.GetMeasurements(),
flags.maaURL,
maaURL,
)
if err != nil {
return fmt.Errorf("printing attestation document: %w", err)
@ -171,114 +222,37 @@ func (c *verifyCmd) verify(cmd *cobra.Command, fileHandler file.Handler, verifyC
return nil
}
func (c *verifyCmd) parseVerifyFlags(cmd *cobra.Command, fileHandler file.Handler) (verifyFlags, error) {
workDir, err := cmd.Flags().GetString("workspace")
if err != nil {
return verifyFlags{}, fmt.Errorf("parsing config path argument: %w", err)
func (c *verifyCmd) validateIDFlags(cmd *cobra.Command, stateFile *state.State) (ownerID, clusterID string, err error) {
ownerID, clusterID = c.flags.ownerID, c.flags.clusterID
if c.flags.clusterID == "" {
cmd.PrintErrf("Using ID from %q. Specify --cluster-id to override this.\n", c.flags.pathPrefixer.PrefixPrintablePath(constants.StateFilename))
clusterID = stateFile.ClusterValues.ClusterID
}
c.log.Debugf("Flag 'workspace' set to %q", workDir)
pf := pathprefix.New(workDir)
ownerID := ""
clusterID, err := cmd.Flags().GetString("cluster-id")
if err != nil {
return verifyFlags{}, fmt.Errorf("parsing cluster-id argument: %w", err)
}
c.log.Debugf("Flag 'cluster-id' set to %q", clusterID)
endpoint, err := cmd.Flags().GetString("node-endpoint")
if err != nil {
return verifyFlags{}, fmt.Errorf("parsing node-endpoint argument: %w", err)
}
c.log.Debugf("Flag 'node-endpoint' set to %q", endpoint)
force, err := cmd.Flags().GetBool("force")
if err != nil {
return verifyFlags{}, fmt.Errorf("parsing force argument: %w", err)
}
c.log.Debugf("Flag 'force' set to %t", force)
output, err := cmd.Flags().GetString("output")
if err != nil {
return verifyFlags{}, fmt.Errorf("parsing raw argument: %w", err)
}
c.log.Debugf("Flag 'output' set to %t", output)
// Get empty values from state file
stateFile, err := state.ReadFromFile(fileHandler, constants.StateFilename)
isFileNotFound := errors.Is(err, afero.ErrFileNotFound)
if isFileNotFound {
c.log.Debugf("State file %q not found, using empty state", pf.PrefixPrintablePath(constants.StateFilename))
stateFile = state.New() // error compat
} else if err != nil {
return verifyFlags{}, fmt.Errorf("reading state file: %w", err)
}
emptyEndpoint := endpoint == ""
emptyIDs := ownerID == "" && clusterID == ""
if emptyEndpoint || emptyIDs {
c.log.Debugf("Trying to supplement empty flag values from %q", pf.PrefixPrintablePath(constants.StateFilename))
if emptyEndpoint {
cmd.PrintErrf("Using endpoint from %q. Specify --node-endpoint to override this.\n", pf.PrefixPrintablePath(constants.StateFilename))
endpoint = stateFile.Infrastructure.ClusterEndpoint
}
if emptyIDs {
cmd.PrintErrf("Using ID from %q. Specify --cluster-id to override this.\n", pf.PrefixPrintablePath(constants.StateFilename))
ownerID = stateFile.ClusterValues.OwnerID
clusterID = stateFile.ClusterValues.ClusterID
}
}
var attestationURL string
if stateFile.Infrastructure.Azure != nil {
attestationURL = stateFile.Infrastructure.Azure.AttestationURL
if ownerID == "" {
// We don't want to print warnings until this is implemented again
// cmd.PrintErrf("Using ID from %q. Specify --owner-id to override this.\n", c.flags.pathPrefixer.PrefixPrintablePath(constants.StateFilename))
ownerID = stateFile.ClusterValues.OwnerID
}
// Validate
if ownerID == "" && clusterID == "" {
return verifyFlags{}, errors.New("cluster-id not provided to verify the cluster")
}
endpoint, err = addPortIfMissing(endpoint, constants.VerifyServiceNodePortGRPC)
if err != nil {
return verifyFlags{}, fmt.Errorf("validating endpoint argument: %w", err)
return "", "", errors.New("cluster-id not provided to verify the cluster")
}
return verifyFlags{
endpoint: endpoint,
pf: pf,
ownerID: ownerID,
clusterID: clusterID,
output: output,
maaURL: attestationURL,
force: force,
}, nil
return ownerID, clusterID, nil
}
type verifyFlags struct {
endpoint string
ownerID string
clusterID string
maaURL string
output string
force bool
pf pathprefix.PathPrefixer
}
func addPortIfMissing(endpoint string, defaultPort int) (string, error) {
func (c *verifyCmd) validateEndpointFlag(cmd *cobra.Command, stateFile *state.State) (string, error) {
endpoint := c.flags.endpoint
if endpoint == "" {
return "", errors.New("endpoint is empty")
cmd.PrintErrf("Using endpoint from %q. Specify --node-endpoint to override this.\n", c.flags.pathPrefixer.PrefixPrintablePath(constants.StateFilename))
endpoint = stateFile.Infrastructure.ClusterEndpoint
}
_, _, err := net.SplitHostPort(endpoint)
if err == nil {
return endpoint, nil
endpoint, err := addPortIfMissing(endpoint, constants.VerifyServiceNodePortGRPC)
if err != nil {
return "", fmt.Errorf("validating endpoint argument: %w", err)
}
if strings.Contains(err.Error(), "missing port in address") {
return net.JoinHostPort(endpoint, strconv.Itoa(defaultPort)), nil
}
return "", err
return endpoint, nil
}
// an attestationDocFormatter formats the attestation document.
@ -869,3 +843,20 @@ func extractAzureInstanceInfo(docString string) (azureInstanceInfo, error) {
}
return instanceInfo, nil
}
func addPortIfMissing(endpoint string, defaultPort int) (string, error) {
if endpoint == "" {
return "", errors.New("endpoint is empty")
}
_, _, err := net.SplitHostPort(endpoint)
if err == nil {
return endpoint, nil
}
if strings.Contains(err.Error(), "missing port in address") {
return net.JoinHostPort(endpoint, strconv.Itoa(defaultPort)), nil
}
return "", err
}