constellation/image/upload/internal/cmd/flags.go

342 lines
7.7 KiB
Go

/*
Copyright (c) Edgeless Systems GmbH
SPDX-License-Identifier: AGPL-3.0-only
*/
package cmd
import (
"errors"
"os"
"path/filepath"
"time"
"github.com/edgelesssys/constellation/v2/internal/api/versionsapi"
"github.com/edgelesssys/constellation/v2/internal/cloud/cloudprovider"
"github.com/spf13/cobra"
"go.uber.org/zap/zapcore"
)
type commonFlags struct {
rawImage string
pki string
provider cloudprovider.Provider
attestationVariant string
version versionsapi.Version
timestamp time.Time
region string
bucket string
out string
logLevel zapcore.Level
}
func parseCommonFlags(cmd *cobra.Command) (commonFlags, error) {
workspaceDir := os.Getenv("BUILD_WORKSPACE_DIRECTORY")
rawImage, err := cmd.Flags().GetString("raw-image")
if err != nil {
return commonFlags{}, err
}
pki, err := cmd.Flags().GetString("pki")
if err != nil {
return commonFlags{}, err
}
if pki == "" {
pki = filepath.Join(workspaceDir, "image/pki")
}
attestationVariant, err := cmd.Flags().GetString("attestation-variant")
if err != nil {
return commonFlags{}, err
}
version, err := cmd.Flags().GetString("version")
if err != nil {
return commonFlags{}, err
}
ver, err := versionsapi.NewVersionFromShortPath(version, versionsapi.VersionKindImage)
if err != nil {
return commonFlags{}, err
}
timestamp, err := cmd.Flags().GetString("timestamp")
if err != nil {
return commonFlags{}, err
}
if timestamp == "" {
timestamp = time.Now().Format("2006-01-02T15:04:05Z07:00")
}
timestmp, err := time.Parse("2006-01-02T15:04:05Z07:00", timestamp)
if err != nil {
return commonFlags{}, err
}
region, err := cmd.Flags().GetString("region")
if err != nil {
return commonFlags{}, err
}
bucket, err := cmd.Flags().GetString("bucket")
if err != nil {
return commonFlags{}, err
}
out, err := cmd.Flags().GetString("out")
if err != nil {
return commonFlags{}, err
}
verbose, err := cmd.Flags().GetBool("verbose")
if err != nil {
return commonFlags{}, err
}
logLevel := zapcore.InfoLevel
if verbose {
logLevel = zapcore.DebugLevel
}
return commonFlags{
rawImage: rawImage,
pki: pki,
attestationVariant: attestationVariant,
version: ver,
timestamp: timestmp,
region: region,
bucket: bucket,
out: out,
logLevel: logLevel,
}, nil
}
type awsFlags struct {
commonFlags
awsRegion string
awsBucket string
}
func parseAWSFlags(cmd *cobra.Command) (awsFlags, error) {
common, err := parseCommonFlags(cmd)
if err != nil {
return awsFlags{}, err
}
awsRegion, err := cmd.Flags().GetString("aws-region")
if err != nil {
return awsFlags{}, err
}
awsBucket, err := cmd.Flags().GetString("aws-bucket")
if err != nil {
return awsFlags{}, err
}
common.provider = cloudprovider.AWS
return awsFlags{
commonFlags: common,
awsRegion: awsRegion,
awsBucket: awsBucket,
}, nil
}
type azureFlags struct {
commonFlags
azSubscription string
azLocation string
azResourceGroup string
}
func parseAzureFlags(cmd *cobra.Command) (azureFlags, error) {
common, err := parseCommonFlags(cmd)
if err != nil {
return azureFlags{}, err
}
azSubscription, err := cmd.Flags().GetString("az-subscription")
if err != nil {
return azureFlags{}, err
}
azLocation, err := cmd.Flags().GetString("az-location")
if err != nil {
return azureFlags{}, err
}
azResourceGroup, err := cmd.Flags().GetString("az-resource-group")
if err != nil {
return azureFlags{}, err
}
common.provider = cloudprovider.Azure
return azureFlags{
commonFlags: common,
azSubscription: azSubscription,
azLocation: azLocation,
azResourceGroup: azResourceGroup,
}, nil
}
type gcpFlags struct {
commonFlags
gcpProject string
gcpLocation string
gcpBucket string
}
func parseGCPFlags(cmd *cobra.Command) (gcpFlags, error) {
common, err := parseCommonFlags(cmd)
if err != nil {
return gcpFlags{}, err
}
gcpProject, err := cmd.Flags().GetString("gcp-project")
if err != nil {
return gcpFlags{}, err
}
gcpLocation, err := cmd.Flags().GetString("gcp-location")
if err != nil {
return gcpFlags{}, err
}
gcpBucket, err := cmd.Flags().GetString("gcp-bucket")
if err != nil {
return gcpFlags{}, err
}
common.provider = cloudprovider.GCP
return gcpFlags{
commonFlags: common,
gcpProject: gcpProject,
gcpLocation: gcpLocation,
gcpBucket: gcpBucket,
}, nil
}
type s3Flags struct {
region string
bucket string
logLevel zapcore.Level
}
func parseS3Flags(cmd *cobra.Command) (s3Flags, error) {
region, err := cmd.Flags().GetString("region")
if err != nil {
return s3Flags{}, err
}
bucket, err := cmd.Flags().GetString("bucket")
if err != nil {
return s3Flags{}, err
}
verbose, err := cmd.Flags().GetBool("verbose")
if err != nil {
return s3Flags{}, err
}
logLevel := zapcore.InfoLevel
if verbose {
logLevel = zapcore.DebugLevel
}
return s3Flags{
region: region,
bucket: bucket,
logLevel: logLevel,
}, nil
}
type measurementsFlags struct {
s3Flags
measurementsPath string
signaturePath string
}
func parseUploadMeasurementsFlags(cmd *cobra.Command) (measurementsFlags, error) {
s3, err := parseS3Flags(cmd)
if err != nil {
return measurementsFlags{}, err
}
measurementsPath, err := cmd.Flags().GetString("measurements")
if err != nil {
return measurementsFlags{}, err
}
signaturePath, err := cmd.Flags().GetString("signature")
if err != nil {
return measurementsFlags{}, err
}
return measurementsFlags{
s3Flags: s3,
measurementsPath: measurementsPath,
signaturePath: signaturePath,
}, nil
}
type mergeMeasurementsFlags struct {
out string
logLevel zapcore.Level
}
func parseMergeMeasurementsFlags(cmd *cobra.Command) (mergeMeasurementsFlags, error) {
out, err := cmd.Flags().GetString("out")
if err != nil {
return mergeMeasurementsFlags{}, err
}
verbose, err := cmd.Flags().GetBool("verbose")
if err != nil {
return mergeMeasurementsFlags{}, err
}
logLevel := zapcore.InfoLevel
if verbose {
logLevel = zapcore.DebugLevel
}
return mergeMeasurementsFlags{
out: out,
logLevel: logLevel,
}, nil
}
type envelopeMeasurementsFlags struct {
version versionsapi.Version
csp cloudprovider.Provider
attestationVariant string
in, out string
logLevel zapcore.Level
}
func parseEnvelopeMeasurementsFlags(cmd *cobra.Command) (envelopeMeasurementsFlags, error) {
version, err := cmd.Flags().GetString("version")
if err != nil {
return envelopeMeasurementsFlags{}, err
}
ver, err := versionsapi.NewVersionFromShortPath(version, versionsapi.VersionKindImage)
if err != nil {
return envelopeMeasurementsFlags{}, err
}
csp, err := cmd.Flags().GetString("csp")
if err != nil {
return envelopeMeasurementsFlags{}, err
}
provider := cloudprovider.FromString(csp)
attestationVariant, err := cmd.Flags().GetString("attestation-variant")
if err != nil {
return envelopeMeasurementsFlags{}, err
}
if provider == cloudprovider.Unknown {
return envelopeMeasurementsFlags{}, errors.New("unknown cloud provider")
}
in, err := cmd.Flags().GetString("in")
if err != nil {
return envelopeMeasurementsFlags{}, err
}
out, err := cmd.Flags().GetString("out")
if err != nil {
return envelopeMeasurementsFlags{}, err
}
verbose, err := cmd.Flags().GetBool("verbose")
if err != nil {
return envelopeMeasurementsFlags{}, err
}
logLevel := zapcore.InfoLevel
if verbose {
logLevel = zapcore.DebugLevel
}
return envelopeMeasurementsFlags{
version: ver,
csp: provider,
attestationVariant: attestationVariant,
in: in,
out: out,
logLevel: logLevel,
}, nil
}