2024-02-08 14:20:01 +00:00

292 lines
7.1 KiB
Go

/*
Copyright (c) Edgeless Systems GmbH
SPDX-License-Identifier: AGPL-3.0-only
*/
package cmd
import (
"errors"
"log/slog"
"path/filepath"
"strings"
"github.com/edgelesssys/constellation/v2/internal/api/versionsapi"
"github.com/edgelesssys/constellation/v2/internal/cloud/cloudprovider"
"github.com/spf13/cobra"
)
type s3Flags struct {
region string
bucket string
distributionID string
logLevel slog.Level
}
func parseS3Flags(cmd *cobra.Command) (s3Flags, error) {
region, err := cmd.Flags().GetString("region")
if err != nil {
return s3Flags{}, err
}
bucket, err := cmd.Flags().GetString("bucket")
if err != nil {
return s3Flags{}, err
}
distributionID, err := cmd.Flags().GetString("distribution-id")
if err != nil {
return s3Flags{}, err
}
verbose, err := cmd.Flags().GetBool("verbose")
if err != nil {
return s3Flags{}, err
}
logLevel := slog.LevelInfo
if verbose {
logLevel = slog.LevelDebug
}
return s3Flags{
region: region,
bucket: bucket,
distributionID: distributionID,
logLevel: logLevel,
}, nil
}
type measurementsFlags struct {
s3Flags
measurementsPath string
signaturePath string
}
func parseUploadMeasurementsFlags(cmd *cobra.Command) (measurementsFlags, error) {
s3, err := parseS3Flags(cmd)
if err != nil {
return measurementsFlags{}, err
}
measurementsPath, err := cmd.Flags().GetString("measurements")
if err != nil {
return measurementsFlags{}, err
}
signaturePath, err := cmd.Flags().GetString("signature")
if err != nil {
return measurementsFlags{}, err
}
return measurementsFlags{
s3Flags: s3,
measurementsPath: measurementsPath,
signaturePath: signaturePath,
}, nil
}
type mergeMeasurementsFlags struct {
out string
logLevel slog.Level
}
func parseMergeMeasurementsFlags(cmd *cobra.Command) (mergeMeasurementsFlags, error) {
out, err := cmd.Flags().GetString("out")
if err != nil {
return mergeMeasurementsFlags{}, err
}
verbose, err := cmd.Flags().GetBool("verbose")
if err != nil {
return mergeMeasurementsFlags{}, err
}
logLevel := slog.LevelInfo
if verbose {
logLevel = slog.LevelDebug
}
return mergeMeasurementsFlags{
out: out,
logLevel: logLevel,
}, nil
}
type envelopeMeasurementsFlags struct {
version versionsapi.Version
csp cloudprovider.Provider
attestationVariant string
in, out string
logLevel slog.Level
}
func parseEnvelopeMeasurementsFlags(cmd *cobra.Command) (envelopeMeasurementsFlags, error) {
version, err := cmd.Flags().GetString("version")
if err != nil {
return envelopeMeasurementsFlags{}, err
}
ver, err := versionsapi.NewVersionFromShortPath(version, versionsapi.VersionKindImage)
if err != nil {
return envelopeMeasurementsFlags{}, err
}
csp, err := cmd.Flags().GetString("csp")
if err != nil {
return envelopeMeasurementsFlags{}, err
}
provider := cloudprovider.FromString(csp)
attestationVariant, err := cmd.Flags().GetString("attestation-variant")
if err != nil {
return envelopeMeasurementsFlags{}, err
}
if provider == cloudprovider.Unknown {
return envelopeMeasurementsFlags{}, errors.New("unknown cloud provider")
}
in, err := cmd.Flags().GetString("in")
if err != nil {
return envelopeMeasurementsFlags{}, err
}
out, err := cmd.Flags().GetString("out")
if err != nil {
return envelopeMeasurementsFlags{}, err
}
verbose, err := cmd.Flags().GetBool("verbose")
if err != nil {
return envelopeMeasurementsFlags{}, err
}
logLevel := slog.LevelInfo
if verbose {
logLevel = slog.LevelDebug
}
return envelopeMeasurementsFlags{
version: ver,
csp: provider,
attestationVariant: attestationVariant,
in: in,
out: out,
logLevel: logLevel,
}, nil
}
type uplosiFlags struct {
rawImage string
provider cloudprovider.Provider
version versionsapi.Version
attestationVariant string
out string
uplosiPath string
// archiver flags
region string
bucket string
distributionID string
logLevel slog.Level
}
func parseUplosiFlags(cmd *cobra.Command) (uplosiFlags, error) {
rawImage, err := cmd.Flags().GetString("raw-image")
if err != nil {
return uplosiFlags{}, err
}
rawImage, err = filepath.Abs(rawImage)
if err != nil {
return uplosiFlags{}, err
}
// extract csp, attestation variant, and stream from the raw image name
// format: <csp>_<attestation-variant>_<stream>/constellation.raw
var guessedCSP, guessedAttestationVariant, guessedStream string
pathComponents := strings.Split(filepath.ToSlash(rawImage), "/")
if len(pathComponents) >= 2 && pathComponents[len(pathComponents)-1] == "constellation.raw" {
imageMetadata := pathComponents[len(pathComponents)-2]
imageMetadataComponents := strings.Split(imageMetadata, "_")
if len(imageMetadataComponents) == 3 {
guessedCSP = imageMetadataComponents[0]
guessedAttestationVariant = imageMetadataComponents[1]
guessedStream = imageMetadataComponents[2]
}
}
csp, err := cmd.Flags().GetString("csp")
if err != nil || csp == "" {
csp = guessedCSP
}
if csp == "" {
if err != nil {
return uplosiFlags{}, err
}
return uplosiFlags{}, errors.New("csp is required")
}
provider := cloudprovider.FromString(csp)
ref, err := cmd.Flags().GetString("ref")
if err != nil {
return uplosiFlags{}, err
}
stream, err := cmd.Flags().GetString("stream")
if err != nil || stream == "" {
stream = guessedStream
}
if stream == "" {
if err != nil {
return uplosiFlags{}, err
}
return uplosiFlags{}, errors.New("stream is required")
}
version, err := cmd.Flags().GetString("version")
if err != nil {
return uplosiFlags{}, err
}
ver, err := versionsapi.NewVersion(ref, stream, version, versionsapi.VersionKindImage)
if err != nil {
return uplosiFlags{}, err
}
attestationVariant, err := cmd.Flags().GetString("attestation-variant")
if err != nil || attestationVariant == "" {
attestationVariant = guessedAttestationVariant
}
if attestationVariant == "" {
if err != nil {
return uplosiFlags{}, err
}
return uplosiFlags{}, errors.New("attestation-variant is required")
}
out, err := cmd.Flags().GetString("out")
if err != nil {
return uplosiFlags{}, err
}
uplosiPath, err := cmd.Flags().GetString("uplosi-path")
if err != nil {
return uplosiFlags{}, err
}
region, err := cmd.Flags().GetString("region")
if err != nil {
return uplosiFlags{}, err
}
bucket, err := cmd.Flags().GetString("bucket")
if err != nil {
return uplosiFlags{}, err
}
distributionID, err := cmd.Flags().GetString("distribution-id")
if err != nil {
return uplosiFlags{}, err
}
verbose, err := cmd.Flags().GetBool("verbose")
if err != nil {
return uplosiFlags{}, err
}
logLevel := slog.LevelInfo
if verbose {
logLevel = slog.LevelDebug
}
return uplosiFlags{
rawImage: rawImage,
provider: provider,
version: ver,
attestationVariant: attestationVariant,
out: out,
uplosiPath: uplosiPath,
region: region,
bucket: bucket,
distributionID: distributionID,
logLevel: logLevel,
}, nil
}