cli: refactor flag parsing code (#2425)

Signed-off-by: Daniel Weiße <dw@edgeless.systems>
This commit is contained in:
Daniel Weiße 2023-10-16 15:05:29 +02:00 committed by GitHub
parent adfe443b28
commit c52086c5ff
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
34 changed files with 1490 additions and 1726 deletions

View file

@ -13,7 +13,11 @@ go_library(
"configkubernetesversions.go", "configkubernetesversions.go",
"configmigrate.go", "configmigrate.go",
"create.go", "create.go",
"iam.go",
"iamcreate.go", "iamcreate.go",
"iamcreateaws.go",
"iamcreateazure.go",
"iamcreategcp.go",
"iamdestroy.go", "iamdestroy.go",
"iamupgradeapply.go", "iamupgradeapply.go",
"init.go", "init.go",
@ -87,6 +91,7 @@ go_library(
"@com_github_siderolabs_talos_pkg_machinery//config/encoder", "@com_github_siderolabs_talos_pkg_machinery//config/encoder",
"@com_github_spf13_afero//:afero", "@com_github_spf13_afero//:afero",
"@com_github_spf13_cobra//:cobra", "@com_github_spf13_cobra//:cobra",
"@com_github_spf13_pflag//:pflag",
"@in_gopkg_yaml_v3//:yaml_v3", "@in_gopkg_yaml_v3//:yaml_v3",
"@io_k8s_apiextensions_apiserver//pkg/apis/apiextensions/v1:apiextensions", "@io_k8s_apiextensions_apiserver//pkg/apis/apiextensions/v1:apiextensions",
"@io_k8s_apimachinery//pkg/runtime", "@io_k8s_apimachinery//pkg/runtime",

View file

@ -20,3 +20,58 @@ Common filepaths are defined as constants in the global "/internal/constants" pa
To generate workspace correct filepaths for printing, use the functions from the "workspace" package. To generate workspace correct filepaths for printing, use the functions from the "workspace" package.
*/ */
package cmd package cmd
import (
"errors"
"fmt"
"github.com/edgelesssys/constellation/v2/cli/internal/cmd/pathprefix"
"github.com/edgelesssys/constellation/v2/cli/internal/terraform"
"github.com/spf13/pflag"
)
// rootFlags are flags defined on the root command.
// They are available to all subcommands.
type rootFlags struct {
pathPrefixer pathprefix.PathPrefixer
tfLogLevel terraform.LogLevel
debug bool
force bool
}
// parse flags into the rootFlags struct.
func (f *rootFlags) parse(flags *pflag.FlagSet) error {
var errs error
workspace, err := flags.GetString("workspace")
if err != nil {
errs = errors.Join(err, fmt.Errorf("getting 'workspace' flag: %w", err))
}
f.pathPrefixer = pathprefix.New(workspace)
tfLogString, err := flags.GetString("tf-log")
if err != nil {
errs = errors.Join(err, fmt.Errorf("getting 'tf-log' flag: %w", err))
}
f.tfLogLevel, err = terraform.ParseLogLevel(tfLogString)
if err != nil {
errs = errors.Join(err, fmt.Errorf("parsing 'tf-log' flag: %w", err))
}
f.debug, err = flags.GetBool("debug")
if err != nil {
errs = errors.Join(err, fmt.Errorf("getting 'debug' flag: %w", err))
}
f.force, err = flags.GetBool("force")
if err != nil {
errs = errors.Join(err, fmt.Errorf("getting 'force' flag: %w", err))
}
return errs
}
func must(err error) {
if err != nil {
panic(err)
}
}

View file

@ -14,7 +14,6 @@ import (
"net/url" "net/url"
"time" "time"
"github.com/edgelesssys/constellation/v2/cli/internal/cmd/pathprefix"
"github.com/edgelesssys/constellation/v2/cli/internal/featureset" "github.com/edgelesssys/constellation/v2/cli/internal/featureset"
"github.com/edgelesssys/constellation/v2/internal/api/attestationconfigapi" "github.com/edgelesssys/constellation/v2/internal/api/attestationconfigapi"
"github.com/edgelesssys/constellation/v2/internal/api/versionsapi" "github.com/edgelesssys/constellation/v2/internal/api/versionsapi"
@ -26,6 +25,7 @@ import (
"github.com/edgelesssys/constellation/v2/internal/sigstore/keyselect" "github.com/edgelesssys/constellation/v2/internal/sigstore/keyselect"
"github.com/spf13/afero" "github.com/spf13/afero"
"github.com/spf13/cobra" "github.com/spf13/cobra"
"github.com/spf13/pflag"
) )
func newConfigFetchMeasurementsCmd() *cobra.Command { func newConfigFetchMeasurementsCmd() *cobra.Command {
@ -46,14 +46,35 @@ func newConfigFetchMeasurementsCmd() *cobra.Command {
} }
type fetchMeasurementsFlags struct { type fetchMeasurementsFlags struct {
rootFlags
measurementsURL *url.URL measurementsURL *url.URL
signatureURL *url.URL signatureURL *url.URL
insecure bool insecure bool
force bool }
pf pathprefix.PathPrefixer
func (f *fetchMeasurementsFlags) parse(flags *pflag.FlagSet) error {
var err error
if err := f.rootFlags.parse(flags); err != nil {
return err
}
f.measurementsURL, err = parseURLFlag(flags, "url")
if err != nil {
return err
}
f.signatureURL, err = parseURLFlag(flags, "signature-url")
if err != nil {
return err
}
f.insecure, err = flags.GetBool("insecure")
if err != nil {
return fmt.Errorf("getting 'insecure' flag: %w", err)
}
return nil
} }
type configFetchMeasurementsCmd struct { type configFetchMeasurementsCmd struct {
flags fetchMeasurementsFlags
canFetchMeasurements bool canFetchMeasurements bool
log debugLog log debugLog
} }
@ -70,6 +91,10 @@ func runConfigFetchMeasurements(cmd *cobra.Command, _ []string) error {
return fmt.Errorf("constructing Rekor client: %w", err) return fmt.Errorf("constructing Rekor client: %w", err)
} }
cfm := &configFetchMeasurementsCmd{log: log, canFetchMeasurements: featureset.CanFetchMeasurements} cfm := &configFetchMeasurementsCmd{log: log, canFetchMeasurements: featureset.CanFetchMeasurements}
if err := cfm.flags.parse(cmd.Flags()); err != nil {
return fmt.Errorf("parsing flags: %w", err)
}
cfm.log.Debugf("Using flags %+v", cfm.flags)
fetcher := attestationconfigapi.NewFetcherWithClient(http.DefaultClient, constants.CDNRepositoryURL) fetcher := attestationconfigapi.NewFetcherWithClient(http.DefaultClient, constants.CDNRepositoryURL)
return cfm.configFetchMeasurements(cmd, sigstore.NewCosignVerifier, rekor, fileHandler, fetcher, http.DefaultClient) return cfm.configFetchMeasurements(cmd, sigstore.NewCosignVerifier, rekor, fileHandler, fetcher, http.DefaultClient)
@ -79,20 +104,14 @@ func (cfm *configFetchMeasurementsCmd) configFetchMeasurements(
cmd *cobra.Command, newCosignVerifier cosignVerifierConstructor, rekor rekorVerifier, cmd *cobra.Command, newCosignVerifier cosignVerifierConstructor, rekor rekorVerifier,
fileHandler file.Handler, fetcher attestationconfigapi.Fetcher, client *http.Client, fileHandler file.Handler, fetcher attestationconfigapi.Fetcher, client *http.Client,
) error { ) error {
flags, err := cfm.parseFetchMeasurementsFlags(cmd)
if err != nil {
return err
}
cfm.log.Debugf("Using flags %v", flags)
if !cfm.canFetchMeasurements { if !cfm.canFetchMeasurements {
cmd.PrintErrln("Fetching measurements is not supported in the OSS build of the Constellation CLI. Consult the documentation for instructions on where to download the enterprise version.") cmd.PrintErrln("Fetching measurements is not supported in the OSS build of the Constellation CLI. Consult the documentation for instructions on where to download the enterprise version.")
return errors.New("fetching measurements is not supported") return errors.New("fetching measurements is not supported")
} }
cfm.log.Debugf("Loading configuration file from %q", flags.pf.PrefixPrintablePath(constants.ConfigFilename)) cfm.log.Debugf("Loading configuration file from %q", cfm.flags.pathPrefixer.PrefixPrintablePath(constants.ConfigFilename))
conf, err := config.New(fileHandler, constants.ConfigFilename, fetcher, flags.force) conf, err := config.New(fileHandler, constants.ConfigFilename, fetcher, cfm.flags.force)
var configValidationErr *config.ValidationError var configValidationErr *config.ValidationError
if errors.As(err, &configValidationErr) { if errors.As(err, &configValidationErr) {
cmd.PrintErrln(configValidationErr.LongMessage()) cmd.PrintErrln(configValidationErr.LongMessage())
@ -110,7 +129,7 @@ func (cfm *configFetchMeasurementsCmd) configFetchMeasurements(
defer cancel() defer cancel()
cfm.log.Debugf("Updating URLs") cfm.log.Debugf("Updating URLs")
if err := flags.updateURLs(conf); err != nil { if err := cfm.flags.updateURLs(conf); err != nil {
return err return err
} }
@ -131,11 +150,11 @@ func (cfm *configFetchMeasurementsCmd) configFetchMeasurements(
var fetchedMeasurements measurements.M var fetchedMeasurements measurements.M
var hash string var hash string
if flags.insecure { if cfm.flags.insecure {
if err := fetchedMeasurements.FetchNoVerify( if err := fetchedMeasurements.FetchNoVerify(
ctx, ctx,
client, client,
flags.measurementsURL, cfm.flags.measurementsURL,
imageVersion, imageVersion,
conf.GetProvider(), conf.GetProvider(),
conf.GetAttestationConfig().GetVariant(), conf.GetAttestationConfig().GetVariant(),
@ -149,8 +168,8 @@ func (cfm *configFetchMeasurementsCmd) configFetchMeasurements(
ctx, ctx,
client, client,
cosign, cosign,
flags.measurementsURL, cfm.flags.measurementsURL,
flags.signatureURL, cfm.flags.signatureURL,
imageVersion, imageVersion,
conf.GetProvider(), conf.GetProvider(),
conf.GetAttestationConfig().GetVariant(), conf.GetAttestationConfig().GetVariant(),
@ -173,63 +192,11 @@ func (cfm *configFetchMeasurementsCmd) configFetchMeasurements(
if err := fileHandler.WriteYAML(constants.ConfigFilename, conf, file.OptOverwrite); err != nil { if err := fileHandler.WriteYAML(constants.ConfigFilename, conf, file.OptOverwrite); err != nil {
return err return err
} }
cfm.log.Debugf("Configuration written to %s", flags.pf.PrefixPrintablePath(constants.ConfigFilename)) cfm.log.Debugf("Configuration written to %s", cfm.flags.pathPrefixer.PrefixPrintablePath(constants.ConfigFilename))
cmd.Print("Successfully fetched measurements and updated Configuration\n") cmd.Print("Successfully fetched measurements and updated Configuration\n")
return nil return nil
} }
// parseURLFlag checks that flag can be parsed as URL.
// If no value was provided for flag, nil is returned.
func (cfm *configFetchMeasurementsCmd) parseURLFlag(cmd *cobra.Command, flag string) (*url.URL, error) {
rawURL, err := cmd.Flags().GetString(flag)
if err != nil {
return nil, fmt.Errorf("parsing config generate flags '%s': %w", flag, err)
}
cfm.log.Debugf("Flag %s has raw URL %q", flag, rawURL)
if rawURL != "" {
cfm.log.Debugf("Parsing raw URL")
return url.Parse(rawURL)
}
return nil, nil
}
func (cfm *configFetchMeasurementsCmd) parseFetchMeasurementsFlags(cmd *cobra.Command) (*fetchMeasurementsFlags, error) {
workDir, err := cmd.Flags().GetString("workspace")
if err != nil {
return nil, fmt.Errorf("parsing workspace argument: %w", err)
}
measurementsURL, err := cfm.parseURLFlag(cmd, "url")
if err != nil {
return nil, err
}
cfm.log.Debugf("Parsed measurements URL as %v", measurementsURL)
measurementsSignatureURL, err := cfm.parseURLFlag(cmd, "signature-url")
if err != nil {
return nil, err
}
cfm.log.Debugf("Parsed measurements signature URL as %v", measurementsSignatureURL)
insecure, err := cmd.Flags().GetBool("insecure")
if err != nil {
return nil, fmt.Errorf("parsing insecure argument: %w", err)
}
cfm.log.Debugf("Insecure flag is %v", insecure)
force, err := cmd.Flags().GetBool("force")
if err != nil {
return nil, fmt.Errorf("parsing force argument: %w", err)
}
return &fetchMeasurementsFlags{
measurementsURL: measurementsURL,
signatureURL: measurementsSignatureURL,
insecure: insecure,
force: force,
pf: pathprefix.New(workDir),
}, nil
}
func (f *fetchMeasurementsFlags) updateURLs(conf *config.Config) error { func (f *fetchMeasurementsFlags) updateURLs(conf *config.Config) error {
ver, err := versionsapi.NewVersionFromShortPath(conf.Image, versionsapi.VersionKindImage) ver, err := versionsapi.NewVersionFromShortPath(conf.Image, versionsapi.VersionKindImage)
if err != nil { if err != nil {
@ -250,6 +217,19 @@ func (f *fetchMeasurementsFlags) updateURLs(conf *config.Config) error {
return nil return nil
} }
// parseURLFlag checks that flag can be parsed as URL.
// If no value was provided for flag, nil is returned.
func parseURLFlag(flags *pflag.FlagSet, flag string) (*url.URL, error) {
rawURL, err := flags.GetString(flag)
if err != nil {
return nil, fmt.Errorf("getting '%s' flag: %w", flag, err)
}
if rawURL != "" {
return url.Parse(rawURL)
}
return nil, nil
}
type rekorVerifier interface { type rekorVerifier interface {
SearchByHash(context.Context, string) ([]string, error) SearchByHash(context.Context, string) ([]string, error)
VerifyEntry(context.Context, string, string) error VerifyEntry(context.Context, string, string) error

View file

@ -13,7 +13,6 @@ import (
"io" "io"
"net/http" "net/http"
"net/url" "net/url"
"strconv"
"testing" "testing"
"github.com/edgelesssys/constellation/v2/internal/api/attestationconfigapi" "github.com/edgelesssys/constellation/v2/internal/api/attestationconfigapi"
@ -39,11 +38,11 @@ func TestParseFetchMeasurementsFlags(t *testing.T) {
urlFlag string urlFlag string
signatureURLFlag string signatureURLFlag string
forceFlag bool forceFlag bool
wantFlags *fetchMeasurementsFlags wantFlags fetchMeasurementsFlags
wantErr bool wantErr bool
}{ }{
"default": { "default": {
wantFlags: &fetchMeasurementsFlags{ wantFlags: fetchMeasurementsFlags{
measurementsURL: nil, measurementsURL: nil,
signatureURL: nil, signatureURL: nil,
}, },
@ -51,7 +50,7 @@ func TestParseFetchMeasurementsFlags(t *testing.T) {
"url": { "url": {
urlFlag: "https://some.other.url/with/path", urlFlag: "https://some.other.url/with/path",
signatureURLFlag: "https://some.other.url/with/path.sig", signatureURLFlag: "https://some.other.url/with/path.sig",
wantFlags: &fetchMeasurementsFlags{ wantFlags: fetchMeasurementsFlags{
measurementsURL: urlMustParse("https://some.other.url/with/path"), measurementsURL: urlMustParse("https://some.other.url/with/path"),
signatureURL: urlMustParse("https://some.other.url/with/path.sig"), signatureURL: urlMustParse("https://some.other.url/with/path.sig"),
}, },
@ -69,7 +68,9 @@ func TestParseFetchMeasurementsFlags(t *testing.T) {
cmd := newConfigFetchMeasurementsCmd() cmd := newConfigFetchMeasurementsCmd()
cmd.Flags().String("workspace", "", "") // register persistent flag manually cmd.Flags().String("workspace", "", "") // register persistent flag manually
cmd.Flags().Bool("force", false, "") // register persistent flag manually cmd.Flags().Bool("force", false, "")
cmd.Flags().Bool("debug", false, "")
cmd.Flags().String("tf-log", "NONE", "")
if tc.urlFlag != "" { if tc.urlFlag != "" {
require.NoError(cmd.Flags().Set("url", tc.urlFlag)) require.NoError(cmd.Flags().Set("url", tc.urlFlag))
@ -77,8 +78,8 @@ func TestParseFetchMeasurementsFlags(t *testing.T) {
if tc.signatureURLFlag != "" { if tc.signatureURLFlag != "" {
require.NoError(cmd.Flags().Set("signature-url", tc.signatureURLFlag)) require.NoError(cmd.Flags().Set("signature-url", tc.signatureURLFlag))
} }
cfm := &configFetchMeasurementsCmd{log: logger.NewTest(t)} var flags fetchMeasurementsFlags
flags, err := cfm.parseFetchMeasurementsFlags(cmd) err := flags.parse(cmd.Flags())
if tc.wantErr { if tc.wantErr {
assert.Error(err) assert.Error(err)
return return
@ -270,9 +271,6 @@ func TestConfigFetchMeasurements(t *testing.T) {
require := require.New(t) require := require.New(t)
cmd := newConfigFetchMeasurementsCmd() cmd := newConfigFetchMeasurementsCmd()
cmd.Flags().String("workspace", "", "") // register persistent flag manually
cmd.Flags().Bool("force", true, "") // register persistent flag manually
require.NoError(cmd.Flags().Set("insecure", strconv.FormatBool(tc.insecureFlag)))
fileHandler := file.NewHandler(afero.NewMemMapFs()) fileHandler := file.NewHandler(afero.NewMemMapFs())
gcpConfig := defaultConfigWithExpectedMeasurements(t, config.Default(), cloudprovider.GCP) gcpConfig := defaultConfigWithExpectedMeasurements(t, config.Default(), cloudprovider.GCP)
@ -281,6 +279,8 @@ func TestConfigFetchMeasurements(t *testing.T) {
err := fileHandler.WriteYAML(constants.ConfigFilename, gcpConfig, file.OptMkdirAll) err := fileHandler.WriteYAML(constants.ConfigFilename, gcpConfig, file.OptMkdirAll)
require.NoError(err) require.NoError(err)
cfm := &configFetchMeasurementsCmd{canFetchMeasurements: true, log: logger.NewTest(t)} cfm := &configFetchMeasurementsCmd{canFetchMeasurements: true, log: logger.NewTest(t)}
cfm.flags.insecure = tc.insecureFlag
cfm.flags.force = true
err = cfm.configFetchMeasurements(cmd, tc.cosign, tc.rekor, fileHandler, stubAttestationFetcher{}, client) err = cfm.configFetchMeasurements(cmd, tc.cosign, tc.rekor, fileHandler, stubAttestationFetcher{}, client)
if tc.wantErr { if tc.wantErr {

View file

@ -10,7 +10,6 @@ import (
"fmt" "fmt"
"strings" "strings"
"github.com/edgelesssys/constellation/v2/cli/internal/cmd/pathprefix"
"github.com/edgelesssys/constellation/v2/internal/attestation/variant" "github.com/edgelesssys/constellation/v2/internal/attestation/variant"
"github.com/edgelesssys/constellation/v2/internal/cloud/cloudprovider" "github.com/edgelesssys/constellation/v2/internal/cloud/cloudprovider"
"github.com/edgelesssys/constellation/v2/internal/config" "github.com/edgelesssys/constellation/v2/internal/config"
@ -19,6 +18,7 @@ import (
"github.com/edgelesssys/constellation/v2/internal/versions" "github.com/edgelesssys/constellation/v2/internal/versions"
"github.com/spf13/afero" "github.com/spf13/afero"
"github.com/spf13/cobra" "github.com/spf13/cobra"
"github.com/spf13/pflag"
"golang.org/x/mod/semver" "golang.org/x/mod/semver"
) )
@ -41,12 +41,33 @@ func newConfigGenerateCmd() *cobra.Command {
} }
type generateFlags struct { type generateFlags struct {
pf pathprefix.PathPrefixer rootFlags
k8sVersion versions.ValidK8sVersion k8sVersion versions.ValidK8sVersion
attestationVariant variant.Variant attestationVariant variant.Variant
} }
func (f *generateFlags) parse(flags *pflag.FlagSet) error {
if err := f.rootFlags.parse(flags); err != nil {
return err
}
k8sVersion, err := parseK8sFlag(flags)
if err != nil {
return err
}
f.k8sVersion = k8sVersion
variant, err := parseAttestationFlag(flags)
if err != nil {
return err
}
f.attestationVariant = variant
return nil
}
type configGenerateCmd struct { type configGenerateCmd struct {
flags generateFlags
log debugLog log debugLog
} }
@ -56,31 +77,32 @@ func runConfigGenerate(cmd *cobra.Command, args []string) error {
return fmt.Errorf("creating logger: %w", err) return fmt.Errorf("creating logger: %w", err)
} }
defer log.Sync() defer log.Sync()
fileHandler := file.NewHandler(afero.NewOsFs()) fileHandler := file.NewHandler(afero.NewOsFs())
provider := cloudprovider.FromString(args[0]) provider := cloudprovider.FromString(args[0])
cg := &configGenerateCmd{log: log} cg := &configGenerateCmd{log: log}
if err := cg.flags.parse(cmd.Flags()); err != nil {
return fmt.Errorf("parsing flags: %w", err)
}
log.Debugf("Parsed flags as %+v", cg.flags)
return cg.configGenerate(cmd, fileHandler, provider, args[0]) return cg.configGenerate(cmd, fileHandler, provider, args[0])
} }
func (cg *configGenerateCmd) configGenerate(cmd *cobra.Command, fileHandler file.Handler, provider cloudprovider.Provider, rawProvider string) error { func (cg *configGenerateCmd) configGenerate(cmd *cobra.Command, fileHandler file.Handler, provider cloudprovider.Provider, rawProvider string) error {
flags, err := parseGenerateFlags(cmd)
if err != nil {
return err
}
cg.log.Debugf("Parsed flags as %v", flags)
cg.log.Debugf("Using cloud provider %s", provider.String()) cg.log.Debugf("Using cloud provider %s", provider.String())
conf, err := createConfigWithAttestationVariant(provider, rawProvider, flags.attestationVariant) conf, err := createConfigWithAttestationVariant(provider, rawProvider, cg.flags.attestationVariant)
if err != nil { if err != nil {
return fmt.Errorf("creating config: %w", err) return fmt.Errorf("creating config: %w", err)
} }
conf.KubernetesVersion = flags.k8sVersion conf.KubernetesVersion = cg.flags.k8sVersion
cg.log.Debugf("Writing YAML data to configuration file") cg.log.Debugf("Writing YAML data to configuration file")
if err := fileHandler.WriteYAML(constants.ConfigFilename, conf, file.OptMkdirAll); err != nil { if err := fileHandler.WriteYAML(constants.ConfigFilename, conf, file.OptMkdirAll); err != nil {
return err return err
} }
cmd.Println("Config file written to", flags.pf.PrefixPrintablePath(constants.ConfigFilename)) cmd.Println("Config file written to", cg.flags.pathPrefixer.PrefixPrintablePath(constants.ConfigFilename))
cmd.Println("Please fill in your CSP-specific configuration before proceeding.") cmd.Println("Please fill in your CSP-specific configuration before proceeding.")
cmd.Println("For more information refer to the documentation:") cmd.Println("For more information refer to the documentation:")
cmd.Println("\thttps://docs.edgeless.systems/constellation/getting-started/first-steps") cmd.Println("\thttps://docs.edgeless.systems/constellation/getting-started/first-steps")
@ -123,46 +145,6 @@ func createConfig(provider cloudprovider.Provider) *config.Config {
return res return res
} }
func parseGenerateFlags(cmd *cobra.Command) (generateFlags, error) {
workDir, err := cmd.Flags().GetString("workspace")
if err != nil {
return generateFlags{}, fmt.Errorf("parsing workspace flag: %w", err)
}
k8sVersion, err := cmd.Flags().GetString("kubernetes")
if err != nil {
return generateFlags{}, fmt.Errorf("parsing Kubernetes flag: %w", err)
}
resolvedVersion, err := versions.ResolveK8sPatchVersion(k8sVersion)
if err != nil {
return generateFlags{}, fmt.Errorf("resolving kubernetes patch version from flag: %w", err)
}
validK8sVersion, err := versions.NewValidK8sVersion(resolvedVersion, true)
if err != nil {
return generateFlags{}, fmt.Errorf("resolving Kubernetes version from flag: %w", err)
}
attestationString, err := cmd.Flags().GetString("attestation")
if err != nil {
return generateFlags{}, fmt.Errorf("parsing attestation flag: %w", err)
}
var attestationVariant variant.Variant
// if no attestation variant is specified, use the default for the cloud provider
if attestationString == "" {
attestationVariant = variant.Dummy{}
} else {
attestationVariant, err = variant.FromString(attestationString)
if err != nil {
return generateFlags{}, fmt.Errorf("invalid attestation variant: %s", attestationString)
}
}
return generateFlags{
pf: pathprefix.New(workDir),
k8sVersion: validK8sVersion,
attestationVariant: attestationVariant,
}, nil
}
// generateCompletion handles the completion of the create command. It is frequently called // generateCompletion handles the completion of the create command. It is frequently called
// while the user types arguments of the command to suggest completion. // while the user types arguments of the command to suggest completion.
func generateCompletion(_ *cobra.Command, args []string, _ string) ([]string, cobra.ShellCompDirective) { func generateCompletion(_ *cobra.Command, args []string, _ string) ([]string, cobra.ShellCompDirective) {
@ -185,3 +167,39 @@ func toString[T any](t []T) []string {
} }
return res return res
} }
func parseK8sFlag(flags *pflag.FlagSet) (versions.ValidK8sVersion, error) {
versionString, err := flags.GetString("kubernetes")
if err != nil {
return "", fmt.Errorf("getting kubernetes flag: %w", err)
}
resolvedVersion, err := versions.ResolveK8sPatchVersion(versionString)
if err != nil {
return "", fmt.Errorf("resolving kubernetes patch version from flag: %w", err)
}
k8sVersion, err := versions.NewValidK8sVersion(resolvedVersion, true)
if err != nil {
return "", fmt.Errorf("resolving Kubernetes version from flag: %w", err)
}
return k8sVersion, nil
}
func parseAttestationFlag(flags *pflag.FlagSet) (variant.Variant, error) {
attestationString, err := flags.GetString("attestation")
if err != nil {
return nil, fmt.Errorf("getting attestation flag: %w", err)
}
var attestationVariant variant.Variant
// if no attestation variant is specified, use the default for the cloud provider
if attestationString == "" {
attestationVariant = variant.Dummy{}
} else {
attestationVariant, err = variant.FromString(attestationString)
if err != nil {
return nil, fmt.Errorf("invalid attestation variant: %s", attestationString)
}
}
return attestationVariant, nil
}

View file

@ -19,13 +19,12 @@ import (
"github.com/edgelesssys/constellation/v2/internal/logger" "github.com/edgelesssys/constellation/v2/internal/logger"
"github.com/edgelesssys/constellation/v2/internal/versions" "github.com/edgelesssys/constellation/v2/internal/versions"
"github.com/spf13/afero" "github.com/spf13/afero"
"github.com/spf13/cobra"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"golang.org/x/mod/semver" "golang.org/x/mod/semver"
) )
func TestConfigGenerateKubernetesVersion(t *testing.T) { func TestParseKubernetesVersion(t *testing.T) {
testCases := map[string]struct { testCases := map[string]struct {
version string version string
wantErr bool wantErr bool
@ -68,22 +67,18 @@ func TestConfigGenerateKubernetesVersion(t *testing.T) {
assert := assert.New(t) assert := assert.New(t)
require := require.New(t) require := require.New(t)
fileHandler := file.NewHandler(afero.NewMemMapFs()) flags := newConfigGenerateCmd().Flags()
cmd := newConfigGenerateCmd()
cmd.Flags().String("workspace", "", "") // register persistent flag manually
if tc.version != "" { if tc.version != "" {
err := cmd.Flags().Set("kubernetes", tc.version) require.NoError(flags.Set("kubernetes", tc.version))
require.NoError(err)
} }
cg := &configGenerateCmd{log: logger.NewTest(t)} version, err := parseK8sFlag(flags)
err := cg.configGenerate(cmd, fileHandler, cloudprovider.Unknown, "")
if tc.wantErr { if tc.wantErr {
assert.Error(err) assert.Error(err)
return return
} }
assert.NoError(err) assert.NoError(err)
assert.Equal(versions.Default, version)
}) })
} }
} }
@ -94,9 +89,14 @@ func TestConfigGenerateDefault(t *testing.T) {
fileHandler := file.NewHandler(afero.NewMemMapFs()) fileHandler := file.NewHandler(afero.NewMemMapFs())
cmd := newConfigGenerateCmd() cmd := newConfigGenerateCmd()
cmd.Flags().String("workspace", "", "") // register persistent flag manually
cg := &configGenerateCmd{log: logger.NewTest(t)} cg := &configGenerateCmd{
log: logger.NewTest(t),
flags: generateFlags{
attestationVariant: variant.Dummy{},
k8sVersion: versions.Default,
},
}
require.NoError(cg.configGenerate(cmd, fileHandler, cloudprovider.Unknown, "")) require.NoError(cg.configGenerate(cmd, fileHandler, cloudprovider.Unknown, ""))
var readConfig config.Config var readConfig config.Config
@ -106,53 +106,47 @@ func TestConfigGenerateDefault(t *testing.T) {
} }
func TestConfigGenerateDefaultProviderSpecific(t *testing.T) { func TestConfigGenerateDefaultProviderSpecific(t *testing.T) {
providers := []cloudprovider.Provider{ testCases := map[string]struct {
cloudprovider.AWS, provider cloudprovider.Provider
cloudprovider.Azure, rawProvider string
cloudprovider.GCP, }{
cloudprovider.OpenStack, "aws": {
provider: cloudprovider.AWS,
},
"azure": {
provider: cloudprovider.Azure,
},
"gcp": {
provider: cloudprovider.GCP,
},
"openstack": {
provider: cloudprovider.OpenStack,
},
"stackit": {
provider: cloudprovider.OpenStack,
rawProvider: "stackit",
},
} }
for _, provider := range providers { for name, tc := range testCases {
t.Run(provider.String(), func(t *testing.T) { t.Run(name, func(t *testing.T) {
assert := assert.New(t) assert := assert.New(t)
require := require.New(t) require := require.New(t)
fileHandler := file.NewHandler(afero.NewMemMapFs()) fileHandler := file.NewHandler(afero.NewMemMapFs())
cmd := newConfigGenerateCmd() cmd := newConfigGenerateCmd()
cmd.Flags().String("workspace", "", "") // register persistent flag manually
wantConf := config.Default() wantConf := config.Default().WithOpenStackProviderDefaults(tc.rawProvider)
wantConf.RemoveProviderAndAttestationExcept(provider) wantConf.RemoveProviderAndAttestationExcept(tc.provider)
cg := &configGenerateCmd{log: logger.NewTest(t)} cg := &configGenerateCmd{
require.NoError(cg.configGenerate(cmd, fileHandler, provider, "")) log: logger.NewTest(t),
flags: generateFlags{
var readConfig config.Config attestationVariant: variant.Dummy{},
err := fileHandler.ReadYAML(constants.ConfigFilename, &readConfig) k8sVersion: versions.Default,
assert.NoError(err) },
assert.Equal(*wantConf, readConfig)
})
} }
} require.NoError(cg.configGenerate(cmd, fileHandler, tc.provider, tc.rawProvider))
func TestConfigGenerateWithStackIt(t *testing.T) {
openStackProviders := []string{"stackit"}
for _, openStackProvider := range openStackProviders {
t.Run(openStackProvider, func(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
fileHandler := file.NewHandler(afero.NewMemMapFs())
cmd := newConfigGenerateCmd()
cmd.Flags().String("workspace", "", "") // register persistent flag manually
wantConf := config.Default().WithOpenStackProviderDefaults(openStackProvider)
wantConf.RemoveProviderAndAttestationExcept(cloudprovider.OpenStack)
cg := &configGenerateCmd{log: logger.NewTest(t)}
require.NoError(cg.configGenerate(cmd, fileHandler, cloudprovider.OpenStack, openStackProvider))
var readConfig config.Config var readConfig config.Config
err := fileHandler.ReadYAML(constants.ConfigFilename, &readConfig) err := fileHandler.ReadYAML(constants.ConfigFilename, &readConfig)
@ -168,9 +162,11 @@ func TestConfigGenerateDefaultExists(t *testing.T) {
fileHandler := file.NewHandler(afero.NewMemMapFs()) fileHandler := file.NewHandler(afero.NewMemMapFs())
require.NoError(fileHandler.Write(constants.ConfigFilename, []byte("foobar"), file.OptNone)) require.NoError(fileHandler.Write(constants.ConfigFilename, []byte("foobar"), file.OptNone))
cmd := newConfigGenerateCmd() cmd := newConfigGenerateCmd()
cmd.Flags().String("workspace", "", "") // register persistent flag manually
cg := &configGenerateCmd{log: logger.NewTest(t)} cg := &configGenerateCmd{
log: logger.NewTest(t),
flags: generateFlags{attestationVariant: variant.Dummy{}},
}
require.Error(cg.configGenerate(cmd, fileHandler, cloudprovider.Unknown, "")) require.Error(cg.configGenerate(cmd, fileHandler, cloudprovider.Unknown, ""))
} }
@ -247,64 +243,61 @@ func TestValidProviderAttestationCombination(t *testing.T) {
} }
} }
func TestAttestationArgument(t *testing.T) { func TestParseAttestationFlag(t *testing.T) {
defaultAttestation := config.Default().Attestation testCases := map[string]struct {
tests := []struct { wantErr bool
name string attestationFlag string
provider cloudprovider.Provider wantVariant variant.Variant
expectErr bool
expectedCfg config.AttestationConfig
setFlag func(*cobra.Command) error
}{ }{
{ "invalid": {
name: "InvalidAttestationArgument", wantErr: true,
provider: cloudprovider.Unknown, attestationFlag: "unknown",
expectErr: true,
setFlag: func(cmd *cobra.Command) error {
return cmd.Flags().Set("attestation", "unknown")
}, },
"AzureTrustedLaunch": {
attestationFlag: "azure-trustedlaunch",
wantVariant: variant.AzureTrustedLaunch{},
}, },
{ "AzureSEVSNP": {
name: "ValidAttestationArgument", attestationFlag: "azure-sev-snp",
provider: cloudprovider.Azure, wantVariant: variant.AzureSEVSNP{},
expectErr: false,
setFlag: func(cmd *cobra.Command) error {
return cmd.Flags().Set("attestation", "azure-trustedlaunch")
}, },
expectedCfg: config.AttestationConfig{AzureTrustedLaunch: defaultAttestation.AzureTrustedLaunch}, "AWSSEVSNP": {
attestationFlag: "aws-sev-snp",
wantVariant: variant.AWSSEVSNP{},
}, },
{ "AWSNitroTPM": {
name: "WithoutAttestationArgument", attestationFlag: "aws-nitro-tpm",
provider: cloudprovider.Azure, wantVariant: variant.AWSNitroTPM{},
expectErr: false,
setFlag: func(cmd *cobra.Command) error {
return nil
}, },
expectedCfg: config.AttestationConfig{AzureSEVSNP: defaultAttestation.AzureSEVSNP}, "GCPSEVES": {
attestationFlag: "gcp-sev-es",
wantVariant: variant.GCPSEVES{},
},
"QEMUVTPM": {
attestationFlag: "qemu-vtpm",
wantVariant: variant.QEMUVTPM{},
},
"no flag": {
wantVariant: variant.Dummy{},
}, },
} }
for _, test := range tests { for name, tc := range testCases {
t.Run(test.name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
require := assert.New(t) require := require.New(t)
assert := assert.New(t) assert := assert.New(t)
cmd := newConfigGenerateCmd() cmd := newConfigGenerateCmd()
cmd.Flags().String("workspace", "", "") // register persistent flag manually if tc.attestationFlag != "" {
require.NoError(test.setFlag(cmd)) require.NoError(cmd.Flags().Set("attestation", tc.attestationFlag))
fileHandler := file.NewHandler(afero.NewMemMapFs())
cg := &configGenerateCmd{log: logger.NewTest(t)}
err := cg.configGenerate(cmd, fileHandler, test.provider, "")
if test.expectErr {
assert.Error(err)
} else {
assert.NoError(err)
var readConfig config.Config
require.NoError(fileHandler.ReadYAML(constants.ConfigFilename, &readConfig))
assert.Equal(test.expectedCfg, readConfig.Attestation)
} }
attestation, err := parseAttestationFlag(cmd.Flags())
if tc.wantErr {
assert.Error(err)
return
}
require.NoError(err)
assert.True(tc.wantVariant.Equal(attestation))
}) })
} }
} }

View file

@ -24,6 +24,7 @@ import (
"github.com/edgelesssys/constellation/v2/internal/semver" "github.com/edgelesssys/constellation/v2/internal/semver"
"github.com/spf13/afero" "github.com/spf13/afero"
"github.com/spf13/cobra" "github.com/spf13/cobra"
"github.com/spf13/pflag"
) )
// NewCreateCmd returns a new cobra.Command for the create command. // NewCreateCmd returns a new cobra.Command for the create command.
@ -39,9 +40,29 @@ func NewCreateCmd() *cobra.Command {
return cmd return cmd
} }
// createFlags contains the parsed flags of the create command.
type createFlags struct {
rootFlags
yes bool
}
// parse parses the flags of the create command.
func (f *createFlags) parse(flags *pflag.FlagSet) error {
if err := f.rootFlags.parse(flags); err != nil {
return err
}
yes, err := flags.GetBool("yes")
if err != nil {
return fmt.Errorf("getting 'yes' flag: %w", err)
}
f.yes = yes
return nil
}
type createCmd struct { type createCmd struct {
log debugLog log debugLog
pf pathprefix.PathPrefixer flags createFlags
} }
func runCreate(cmd *cobra.Command, _ []string) error { func runCreate(cmd *cobra.Command, _ []string) error {
@ -59,22 +80,22 @@ func runCreate(cmd *cobra.Command, _ []string) error {
fileHandler := file.NewHandler(afero.NewOsFs()) fileHandler := file.NewHandler(afero.NewOsFs())
creator := cloudcmd.NewCreator(spinner) creator := cloudcmd.NewCreator(spinner)
c := &createCmd{log: log} c := &createCmd{log: log}
if err := c.flags.parse(cmd.Flags()); err != nil {
return err
}
c.log.Debugf("Using flags: %+v", c.flags)
fetcher := attestationconfigapi.NewFetcher() fetcher := attestationconfigapi.NewFetcher()
return c.create(cmd, creator, fileHandler, spinner, fetcher) return c.create(cmd, creator, fileHandler, spinner, fetcher)
} }
func (c *createCmd) create(cmd *cobra.Command, creator cloudCreator, fileHandler file.Handler, spinner spinnerInterf, fetcher attestationconfigapi.Fetcher) (retErr error) { func (c *createCmd) create(cmd *cobra.Command, creator cloudCreator, fileHandler file.Handler, spinner spinnerInterf, fetcher attestationconfigapi.Fetcher) (retErr error) {
flags, err := c.parseCreateFlags(cmd)
if err != nil {
return err
}
c.log.Debugf("Using flags: %+v", flags)
if err := c.checkDirClean(fileHandler); err != nil { if err := c.checkDirClean(fileHandler); err != nil {
return err return err
} }
c.log.Debugf("Loading configuration file from %q", c.pf.PrefixPrintablePath(constants.ConfigFilename)) c.log.Debugf("Loading configuration file from %q", c.flags.pathPrefixer.PrefixPrintablePath(constants.ConfigFilename))
conf, err := config.New(fileHandler, constants.ConfigFilename, fetcher, flags.force) conf, err := config.New(fileHandler, constants.ConfigFilename, fetcher, c.flags.force)
c.log.Debugf("Configuration file loaded: %+v", conf) c.log.Debugf("Configuration file loaded: %+v", conf)
var configValidationErr *config.ValidationError var configValidationErr *config.ValidationError
if errors.As(err, &configValidationErr) { if errors.As(err, &configValidationErr) {
@ -83,7 +104,7 @@ func (c *createCmd) create(cmd *cobra.Command, creator cloudCreator, fileHandler
if err != nil { if err != nil {
return err return err
} }
if !flags.force { if !c.flags.force {
if err := validateCLIandConstellationVersionAreEqual(constants.BinaryVersion(), conf.Image, conf.MicroserviceVersion); err != nil { if err := validateCLIandConstellationVersionAreEqual(constants.BinaryVersion(), conf.Image, conf.MicroserviceVersion); err != nil {
return err return err
} }
@ -137,7 +158,7 @@ func (c *createCmd) create(cmd *cobra.Command, creator cloudCreator, fileHandler
c.log.Debugf("Creating %d additional node groups: %v", len(otherGroupNames), otherGroupNames) c.log.Debugf("Creating %d additional node groups: %v", len(otherGroupNames), otherGroupNames)
} }
if !flags.yes { if !c.flags.yes {
// Ask user to confirm action. // Ask user to confirm action.
cmd.Printf("The following Constellation cluster will be created:\n") cmd.Printf("The following Constellation cluster will be created:\n")
cmd.Printf(" %d control-plane node%s of type %s will be created.\n", controlPlaneGroup.InitialCount, isPlural(controlPlaneGroup.InitialCount), controlPlaneGroup.InstanceType) cmd.Printf(" %d control-plane node%s of type %s will be created.\n", controlPlaneGroup.InitialCount, isPlural(controlPlaneGroup.InitialCount), controlPlaneGroup.InstanceType)
@ -160,13 +181,13 @@ func (c *createCmd) create(cmd *cobra.Command, creator cloudCreator, fileHandler
opts := cloudcmd.CreateOptions{ opts := cloudcmd.CreateOptions{
Provider: provider, Provider: provider,
Config: conf, Config: conf,
TFLogLevel: flags.tfLogLevel, TFLogLevel: c.flags.tfLogLevel,
TFWorkspace: constants.TerraformWorkingDir, TFWorkspace: constants.TerraformWorkingDir,
} }
infraState, err := creator.Create(cmd.Context(), opts) infraState, err := creator.Create(cmd.Context(), opts)
spinner.Stop() spinner.Stop()
if err != nil { if err != nil {
return translateCreateErrors(cmd, c.pf, err) return translateCreateErrors(cmd, c.flags.pathPrefixer, err)
} }
c.log.Debugf("Successfully created the cloud resources for the cluster") c.log.Debugf("Successfully created the cloud resources for the cluster")
@ -179,64 +200,28 @@ func (c *createCmd) create(cmd *cobra.Command, creator cloudCreator, fileHandler
return nil return nil
} }
// parseCreateFlags parses the flags of the create command.
func (c *createCmd) parseCreateFlags(cmd *cobra.Command) (createFlags, error) {
yes, err := cmd.Flags().GetBool("yes")
if err != nil {
return createFlags{}, fmt.Errorf("parsing yes bool: %w", err)
}
c.log.Debugf("Yes flag is %t", yes)
workDir, err := cmd.Flags().GetString("workspace")
if err != nil {
return createFlags{}, fmt.Errorf("parsing config path argument: %w", err)
}
c.log.Debugf("Workspace set to %q", workDir)
c.pf = pathprefix.New(workDir)
force, err := cmd.Flags().GetBool("force")
if err != nil {
return createFlags{}, fmt.Errorf("parsing force argument: %w", err)
}
c.log.Debugf("force flag is %t", force)
logLevelString, err := cmd.Flags().GetString("tf-log")
if err != nil {
return createFlags{}, fmt.Errorf("parsing tf-log string: %w", err)
}
logLevel, err := terraform.ParseLogLevel(logLevelString)
if err != nil {
return createFlags{}, fmt.Errorf("parsing Terraform log level %s: %w", logLevelString, err)
}
c.log.Debugf("Terraform logs will be written into %s at level %s", c.pf.PrefixPrintablePath(constants.TerraformLogFile), logLevel.String())
return createFlags{
tfLogLevel: logLevel,
force: force,
yes: yes,
}, nil
}
// createFlags contains the parsed flags of the create command.
type createFlags struct {
tfLogLevel terraform.LogLevel
force bool
yes bool
}
// checkDirClean checks if files of a previous Constellation are left in the current working dir. // checkDirClean checks if files of a previous Constellation are left in the current working dir.
func (c *createCmd) checkDirClean(fileHandler file.Handler) error { func (c *createCmd) checkDirClean(fileHandler file.Handler) error {
c.log.Debugf("Checking admin configuration file") c.log.Debugf("Checking admin configuration file")
if _, err := fileHandler.Stat(constants.AdminConfFilename); !errors.Is(err, fs.ErrNotExist) { if _, err := fileHandler.Stat(constants.AdminConfFilename); !errors.Is(err, fs.ErrNotExist) {
return fmt.Errorf("file '%s' already exists in working directory, run 'constellation terminate' before creating a new one", c.pf.PrefixPrintablePath(constants.AdminConfFilename)) return fmt.Errorf(
"file '%s' already exists in working directory, run 'constellation terminate' before creating a new one",
c.flags.pathPrefixer.PrefixPrintablePath(constants.AdminConfFilename),
)
} }
c.log.Debugf("Checking master secrets file") c.log.Debugf("Checking master secrets file")
if _, err := fileHandler.Stat(constants.MasterSecretFilename); !errors.Is(err, fs.ErrNotExist) { if _, err := fileHandler.Stat(constants.MasterSecretFilename); !errors.Is(err, fs.ErrNotExist) {
return fmt.Errorf("file '%s' already exists in working directory. Constellation won't overwrite previous master secrets. Move it somewhere or delete it before creating a new cluster", c.pf.PrefixPrintablePath(constants.MasterSecretFilename)) return fmt.Errorf(
"file '%s' already exists in working directory. Constellation won't overwrite previous master secrets. Move it somewhere or delete it before creating a new cluster",
c.flags.pathPrefixer.PrefixPrintablePath(constants.MasterSecretFilename),
)
} }
c.log.Debugf("Checking state file") c.log.Debugf("Checking state file")
if _, err := fileHandler.Stat(constants.StateFilename); !errors.Is(err, fs.ErrNotExist) { if _, err := fileHandler.Stat(constants.StateFilename); !errors.Is(err, fs.ErrNotExist) {
return fmt.Errorf("file '%s' already exists in working directory. Constellation won't overwrite previous cluster state. Move it somewhere or delete it before creating a new cluster", c.pf.PrefixPrintablePath(constants.StateFilename)) return fmt.Errorf(
"file '%s' already exists in working directory. Constellation won't overwrite previous cluster state. Move it somewhere or delete it before creating a new cluster",
c.flags.pathPrefixer.PrefixPrintablePath(constants.StateFilename),
)
} }
return nil return nil
@ -270,12 +255,6 @@ func isPlural(count int) string {
return "s" return "s"
} }
func must(err error) {
if err != nil {
panic(err)
}
}
// validateCLIandConstellationVersionAreEqual checks if the image and microservice version are equal (down to patch level) to the CLI version. // validateCLIandConstellationVersionAreEqual checks if the image and microservice version are equal (down to patch level) to the CLI version.
func validateCLIandConstellationVersionAreEqual(cliVersion semver.Semver, imageVersion string, microserviceVersion semver.Semver) error { func validateCLIandConstellationVersionAreEqual(cliVersion semver.Semver, imageVersion string, microserviceVersion semver.Semver) error {
parsedImageVersion, err := versionsapi.NewVersionFromShortPath(imageVersion, versionsapi.VersionKindImage) parsedImageVersion, err := versionsapi.NewVersionFromShortPath(imageVersion, versionsapi.VersionKindImage)

View file

@ -133,16 +133,9 @@ func TestCreate(t *testing.T) {
cmd.SetOut(&bytes.Buffer{}) cmd.SetOut(&bytes.Buffer{})
cmd.SetErr(&bytes.Buffer{}) cmd.SetErr(&bytes.Buffer{})
cmd.SetIn(bytes.NewBufferString(tc.stdin)) cmd.SetIn(bytes.NewBufferString(tc.stdin))
cmd.Flags().String("workspace", "", "") // register persistent flag manually
cmd.Flags().Bool("force", true, "") // register persistent flag manually
cmd.Flags().String("tf-log", "NONE", "") // register persistent flag manually
if tc.yesFlag {
require.NoError(cmd.Flags().Set("yes", "true"))
}
fileHandler := file.NewHandler(tc.setupFs(require, tc.provider)) fileHandler := file.NewHandler(tc.setupFs(require, tc.provider))
c := &createCmd{log: logger.NewTest(t)} c := &createCmd{log: logger.NewTest(t), flags: createFlags{yes: tc.yesFlag}}
err := c.create(cmd, tc.creator, fileHandler, &nopSpinner{}, stubAttestationFetcher{}) err := c.create(cmd, tc.creator, fileHandler, &nopSpinner{}, stubAttestationFetcher{})
if tc.wantErr { if tc.wantErr {

23
cli/internal/cmd/iam.go Normal file
View file

@ -0,0 +1,23 @@
/*
Copyright (c) Edgeless Systems GmbH
SPDX-License-Identifier: AGPL-3.0-only
*/
package cmd
import "github.com/spf13/cobra"
// NewIAMCmd returns a new cobra.Command for the iam parent command. It needs another verb and does nothing on its own.
func NewIAMCmd() *cobra.Command {
cmd := &cobra.Command{
Use: "iam",
Short: "Work with the IAM configuration on your cloud provider",
Long: "Work with the IAM configuration on your cloud provider.",
Args: cobra.ExactArgs(0),
}
cmd.AddCommand(newIAMCreateCmd())
cmd.AddCommand(newIAMDestroyCmd())
cmd.AddCommand(newIAMUpgradeCmd())
return cmd
}

View file

@ -11,17 +11,15 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"regexp" "regexp"
"strings"
"github.com/edgelesssys/constellation/v2/cli/internal/cloudcmd" "github.com/edgelesssys/constellation/v2/cli/internal/cloudcmd"
"github.com/edgelesssys/constellation/v2/cli/internal/cmd/pathprefix"
"github.com/edgelesssys/constellation/v2/cli/internal/terraform"
"github.com/edgelesssys/constellation/v2/internal/cloud/cloudprovider" "github.com/edgelesssys/constellation/v2/internal/cloud/cloudprovider"
"github.com/edgelesssys/constellation/v2/internal/config" "github.com/edgelesssys/constellation/v2/internal/config"
"github.com/edgelesssys/constellation/v2/internal/constants" "github.com/edgelesssys/constellation/v2/internal/constants"
"github.com/edgelesssys/constellation/v2/internal/file" "github.com/edgelesssys/constellation/v2/internal/file"
"github.com/spf13/afero" "github.com/spf13/afero"
"github.com/spf13/cobra" "github.com/spf13/cobra"
"github.com/spf13/pflag"
) )
var ( var (
@ -33,22 +31,7 @@ var (
gcpIDRegex = regexp.MustCompile(`^[a-z][-a-z0-9]{4,28}[a-z0-9]$`) gcpIDRegex = regexp.MustCompile(`^[a-z][-a-z0-9]{4,28}[a-z0-9]$`)
) )
// NewIAMCmd returns a new cobra.Command for the iam parent command. It needs another verb and does nothing on its own. // newIAMCreateCmd returns a new cobra.Command for the iam create parent command. It needs another verb, and does nothing on its own.
func NewIAMCmd() *cobra.Command {
cmd := &cobra.Command{
Use: "iam",
Short: "Work with the IAM configuration on your cloud provider",
Long: "Work with the IAM configuration on your cloud provider.",
Args: cobra.ExactArgs(0),
}
cmd.AddCommand(newIAMCreateCmd())
cmd.AddCommand(newIAMDestroyCmd())
cmd.AddCommand(newIAMUpgradeCmd())
return cmd
}
// NewIAMCreateCmd returns a new cobra.Command for the iam create parent command. It needs another verb, and does nothing on its own.
func newIAMCreateCmd() *cobra.Command { func newIAMCreateCmd() *cobra.Command {
cmd := &cobra.Command{ cmd := &cobra.Command{
Use: "create", Use: "create",
@ -67,135 +50,54 @@ func newIAMCreateCmd() *cobra.Command {
return cmd return cmd
} }
// newIAMCreateAWSCmd returns a new cobra.Command for the iam create aws command. type iamCreateFlags struct {
func newIAMCreateAWSCmd() *cobra.Command { rootFlags
cmd := &cobra.Command{ yes bool
Use: "aws", updateConfig bool
Short: "Create IAM configuration on AWS for your Constellation cluster",
Long: "Create IAM configuration on AWS for your Constellation cluster.",
Args: cobra.ExactArgs(0),
RunE: createRunIAMFunc(cloudprovider.AWS),
}
cmd.Flags().String("prefix", "", "name prefix for all resources (required)")
must(cobra.MarkFlagRequired(cmd.Flags(), "prefix"))
cmd.Flags().String("zone", "", "AWS availability zone the resources will be created in, e.g., us-east-2a (required)\n"+
"See the Constellation docs for a list of currently supported regions.")
must(cobra.MarkFlagRequired(cmd.Flags(), "zone"))
return cmd
} }
// newIAMCreateAzureCmd returns a new cobra.Command for the iam create azure command. func (f *iamCreateFlags) parse(flags *pflag.FlagSet) error {
func newIAMCreateAzureCmd() *cobra.Command { var err error
cmd := &cobra.Command{ if err = f.rootFlags.parse(flags); err != nil {
Use: "azure", return err
Short: "Create IAM configuration on Microsoft Azure for your Constellation cluster",
Long: "Create IAM configuration on Microsoft Azure for your Constellation cluster.",
Args: cobra.ExactArgs(0),
RunE: createRunIAMFunc(cloudprovider.Azure),
} }
f.yes, err = flags.GetBool("yes")
cmd.Flags().String("resourceGroup", "", "name prefix of the two resource groups your cluster / IAM resources will be created in (required)") if err != nil {
must(cobra.MarkFlagRequired(cmd.Flags(), "resourceGroup")) return fmt.Errorf("getting 'yes' flag: %w", err)
cmd.Flags().String("region", "", "region the resources will be created in, e.g., westus (required)") }
must(cobra.MarkFlagRequired(cmd.Flags(), "region")) f.updateConfig, err = flags.GetBool("update-config")
cmd.Flags().String("servicePrincipal", "", "name of the service principal that will be created (required)") if err != nil {
must(cobra.MarkFlagRequired(cmd.Flags(), "servicePrincipal")) return fmt.Errorf("getting 'update-config' flag: %w", err)
return cmd }
return nil
} }
// NewIAMCreateGCPCmd returns a new cobra.Command for the iam create gcp command. func runIAMCreate(cmd *cobra.Command, providerCreator providerIAMCreator, provider cloudprovider.Provider) error {
func newIAMCreateGCPCmd() *cobra.Command {
cmd := &cobra.Command{
Use: "gcp",
Short: "Create IAM configuration on GCP for your Constellation cluster",
Long: "Create IAM configuration on GCP for your Constellation cluster.",
Args: cobra.ExactArgs(0),
RunE: createRunIAMFunc(cloudprovider.GCP),
}
cmd.Flags().String("zone", "", "GCP zone the cluster will be deployed in (required)\n"+
"Find a list of available zones here: https://cloud.google.com/compute/docs/regions-zones#available")
must(cobra.MarkFlagRequired(cmd.Flags(), "zone"))
cmd.Flags().String("serviceAccountID", "", "ID for the service account that will be created (required)\n"+
"Must be 6 to 30 lowercase letters, digits, or hyphens.")
must(cobra.MarkFlagRequired(cmd.Flags(), "serviceAccountID"))
cmd.Flags().String("projectID", "", "ID of the GCP project the configuration will be created in (required)\n"+
"Find it on the welcome screen of your project: https://console.cloud.google.com/welcome")
must(cobra.MarkFlagRequired(cmd.Flags(), "projectID"))
return cmd
}
// createRunIAMFunc is the entrypoint for the iam create command. It sets up the iamCreator
// and starts IAM creation for the specific cloud provider.
func createRunIAMFunc(provider cloudprovider.Provider) func(cmd *cobra.Command, args []string) error {
var providerCreator func(pf pathprefix.PathPrefixer) providerIAMCreator
switch provider {
case cloudprovider.AWS:
providerCreator = func(pathprefix.PathPrefixer) providerIAMCreator { return &awsIAMCreator{} }
case cloudprovider.Azure:
providerCreator = func(pathprefix.PathPrefixer) providerIAMCreator { return &azureIAMCreator{} }
case cloudprovider.GCP:
providerCreator = func(pf pathprefix.PathPrefixer) providerIAMCreator {
return &gcpIAMCreator{pf}
}
default:
return func(cmd *cobra.Command, args []string) error {
return fmt.Errorf("unknown provider %s", provider)
}
}
return func(cmd *cobra.Command, args []string) error {
logLevelString, err := cmd.Flags().GetString("tf-log")
if err != nil {
return fmt.Errorf("parsing tf-log string: %w", err)
}
logLevel, err := terraform.ParseLogLevel(logLevelString)
if err != nil {
return fmt.Errorf("parsing Terraform log level %s: %w", logLevelString, err)
}
workDir, err := cmd.Flags().GetString("workspace")
if err != nil {
return fmt.Errorf("parsing workspace string: %w", err)
}
pf := pathprefix.New(workDir)
iamCreator, err := newIAMCreator(cmd, pf, logLevel)
if err != nil {
return fmt.Errorf("creating iamCreator: %w", err)
}
defer iamCreator.spinner.Stop()
defer iamCreator.log.Sync()
iamCreator.provider = provider
iamCreator.providerCreator = providerCreator(pf)
return iamCreator.create(cmd.Context())
}
}
// newIAMCreator creates a new iamiamCreator.
func newIAMCreator(cmd *cobra.Command, pf pathprefix.PathPrefixer, logLevel terraform.LogLevel) (*iamCreator, error) {
spinner, err := newSpinnerOrStderr(cmd) spinner, err := newSpinnerOrStderr(cmd)
if err != nil { if err != nil {
return nil, fmt.Errorf("creating spinner: %w", err) return fmt.Errorf("creating spinner: %w", err)
} }
defer spinner.Stop()
log, err := newCLILogger(cmd) log, err := newCLILogger(cmd)
if err != nil { if err != nil {
return nil, fmt.Errorf("creating logger: %w", err) return fmt.Errorf("creating logger: %w", err)
} }
log.Debugf("Terraform logs will be written into %s at level %s", pf.PrefixPrintablePath(constants.TerraformLogFile), logLevel.String()) defer log.Sync()
return &iamCreator{ iamCreator := &iamCreator{
cmd: cmd, cmd: cmd,
spinner: spinner, spinner: spinner,
log: log, log: log,
creator: cloudcmd.NewIAMCreator(spinner), creator: cloudcmd.NewIAMCreator(spinner),
fileHandler: file.NewHandler(afero.NewOsFs()), fileHandler: file.NewHandler(afero.NewOsFs()),
iamConfig: &cloudcmd.IAMConfigOptions{ providerCreator: providerCreator,
TFWorkspace: constants.TerraformIAMWorkingDir, provider: provider,
TFLogLevel: logLevel, }
}, if err := iamCreator.flags.parse(cmd.Flags()); err != nil {
}, nil return err
}
return iamCreator.create(cmd.Context())
} }
// iamCreator is the iamCreator for the iam create command. // iamCreator is the iamCreator for the iam create command.
@ -208,24 +110,18 @@ type iamCreator struct {
providerCreator providerIAMCreator providerCreator providerIAMCreator
iamConfig *cloudcmd.IAMConfigOptions iamConfig *cloudcmd.IAMConfigOptions
log debugLog log debugLog
pf pathprefix.PathPrefixer flags iamCreateFlags
} }
// create IAM configuration on the iamCreator's cloud provider. // create IAM configuration on the iamCreator's cloud provider.
func (c *iamCreator) create(ctx context.Context) error { func (c *iamCreator) create(ctx context.Context) error {
flags, err := c.parseFlagsAndSetupConfig()
if err != nil {
return err
}
c.log.Debugf("Using flags: %+v", flags)
if err := c.checkWorkingDir(); err != nil { if err := c.checkWorkingDir(); err != nil {
return err return err
} }
if !flags.yesFlag { if !c.flags.yes {
c.cmd.Printf("The following IAM configuration will be created:\n\n") c.cmd.Printf("The following IAM configuration will be created:\n\n")
c.providerCreator.printConfirmValues(c.cmd, flags) c.providerCreator.printConfirmValues(c.cmd)
ok, err := askToConfirm(c.cmd, "Do you want to create the configuration?") ok, err := askToConfirm(c.cmd, "Do you want to create the configuration?")
if err != nil { if err != nil {
return err return err
@ -237,19 +133,22 @@ func (c *iamCreator) create(ctx context.Context) error {
} }
var conf config.Config var conf config.Config
if flags.updateConfig { if c.flags.updateConfig {
c.log.Debugf("Parsing config %s", c.pf.PrefixPrintablePath(constants.ConfigFilename)) c.log.Debugf("Parsing config %s", c.flags.pathPrefixer.PrefixPrintablePath(constants.ConfigFilename))
if err = c.fileHandler.ReadYAML(constants.ConfigFilename, &conf); err != nil { if err := c.fileHandler.ReadYAML(constants.ConfigFilename, &conf); err != nil {
return fmt.Errorf("error reading the configuration file: %w", err) return fmt.Errorf("error reading the configuration file: %w", err)
} }
if err := validateConfigWithFlagCompatibility(c.provider, conf, flags); err != nil { if err := c.providerCreator.validateConfigWithFlagCompatibility(conf); err != nil {
return err return err
} }
c.cmd.Printf("The configuration file %q will be automatically updated with the IAM values and zone/region information.\n", c.pf.PrefixPrintablePath(constants.ConfigFilename)) c.cmd.Printf("The configuration file %q will be automatically updated with the IAM values and zone/region information.\n", c.flags.pathPrefixer.PrefixPrintablePath(constants.ConfigFilename))
} }
iamConfig := c.providerCreator.getIAMConfigOptions()
iamConfig.TFWorkspace = constants.TerraformIAMWorkingDir
iamConfig.TFLogLevel = c.flags.tfLogLevel
c.spinner.Start("Creating", false) c.spinner.Start("Creating", false)
iamFile, err := c.creator.Create(ctx, c.provider, c.iamConfig) iamFile, err := c.creator.Create(ctx, c.provider, iamConfig)
c.spinner.Stop() c.spinner.Stop()
if err != nil { if err != nil {
return err return err
@ -262,321 +161,47 @@ func (c *iamCreator) create(ctx context.Context) error {
return err return err
} }
if flags.updateConfig { if c.flags.updateConfig {
c.log.Debugf("Writing IAM configuration to %s", c.pf.PrefixPrintablePath(constants.ConfigFilename)) c.log.Debugf("Writing IAM configuration to %s", c.flags.pathPrefixer.PrefixPrintablePath(constants.ConfigFilename))
c.providerCreator.writeOutputValuesToConfig(&conf, flags, iamFile) c.providerCreator.writeOutputValuesToConfig(&conf, iamFile)
if err := c.fileHandler.WriteYAML(constants.ConfigFilename, conf, file.OptOverwrite); err != nil { if err := c.fileHandler.WriteYAML(constants.ConfigFilename, conf, file.OptOverwrite); err != nil {
return err return err
} }
c.cmd.Printf("Your IAM configuration was created and filled into %s successfully.\n", c.pf.PrefixPrintablePath(constants.ConfigFilename)) c.cmd.Printf("Your IAM configuration was created and filled into %s successfully.\n", c.flags.pathPrefixer.PrefixPrintablePath(constants.ConfigFilename))
return nil return nil
} }
c.providerCreator.printOutputValues(c.cmd, flags, iamFile) c.providerCreator.printOutputValues(c.cmd, iamFile)
c.cmd.Println("Your IAM configuration was created successfully. Please fill the above values into your configuration file.") c.cmd.Println("Your IAM configuration was created successfully. Please fill the above values into your configuration file.")
return nil return nil
} }
// parseFlagsAndSetupConfig parses the flags of the iam create command and fills the values into the IAM config (output values of the command).
func (c *iamCreator) parseFlagsAndSetupConfig() (iamFlags, error) {
workDir, err := c.cmd.Flags().GetString("workspace")
if err != nil {
return iamFlags{}, fmt.Errorf("parsing config string: %w", err)
}
c.pf = pathprefix.New(workDir)
yesFlag, err := c.cmd.Flags().GetBool("yes")
if err != nil {
return iamFlags{}, fmt.Errorf("parsing yes bool: %w", err)
}
updateConfig, err := c.cmd.Flags().GetBool("update-config")
if err != nil {
return iamFlags{}, fmt.Errorf("parsing update-config bool: %w", err)
}
flags := iamFlags{
yesFlag: yesFlag,
updateConfig: updateConfig,
}
flags, err = c.providerCreator.parseFlagsAndSetupConfig(c.cmd, flags, c.iamConfig)
if err != nil {
return iamFlags{}, fmt.Errorf("parsing provider-specific value: %w", err)
}
return flags, nil
}
// checkWorkingDir checks if the current working directory already contains a Terraform dir. // checkWorkingDir checks if the current working directory already contains a Terraform dir.
func (c *iamCreator) checkWorkingDir() error { func (c *iamCreator) checkWorkingDir() error {
if _, err := c.fileHandler.Stat(constants.TerraformIAMWorkingDir); err == nil { if _, err := c.fileHandler.Stat(constants.TerraformIAMWorkingDir); err == nil {
return fmt.Errorf("the current working directory already contains the Terraform workspace directory %q. Please run the command in a different directory or destroy the existing workspace", c.pf.PrefixPrintablePath(constants.TerraformIAMWorkingDir)) return fmt.Errorf(
"the current working directory already contains the Terraform workspace directory %q. Please run the command in a different directory or destroy the existing workspace",
c.flags.pathPrefixer.PrefixPrintablePath(constants.TerraformIAMWorkingDir),
)
} }
return nil return nil
} }
// iamFlags contains the parsed flags of the iam create command, including the parsed flags of the selected cloud provider.
type iamFlags struct {
aws awsFlags
azure azureFlags
gcp gcpFlags
yesFlag bool
updateConfig bool
}
// awsFlags contains the parsed flags of the iam create aws command.
type awsFlags struct {
prefix string
region string
zone string
}
// azureFlags contains the parsed flags of the iam create azure command.
type azureFlags struct {
region string
resourceGroup string
servicePrincipal string
}
// gcpFlags contains the parsed flags of the iam create gcp command.
type gcpFlags struct {
serviceAccountID string
zone string
region string
projectID string
}
// providerIAMCreator is an interface for the IAM actions of different cloud providers. // providerIAMCreator is an interface for the IAM actions of different cloud providers.
type providerIAMCreator interface { type providerIAMCreator interface {
// printConfirmValues prints the values that will be created on the cloud provider and need to be confirmed by the user. // printConfirmValues prints the values that will be created on the cloud provider and need to be confirmed by the user.
printConfirmValues(cmd *cobra.Command, flags iamFlags) printConfirmValues(cmd *cobra.Command)
// printOutputValues prints the values that were created on the cloud provider. // printOutputValues prints the values that were created on the cloud provider.
printOutputValues(cmd *cobra.Command, flags iamFlags, iamFile cloudcmd.IAMOutput) printOutputValues(cmd *cobra.Command, iamFile cloudcmd.IAMOutput)
// writeOutputValuesToConfig writes the output values of the IAM creation to the constellation config file. // writeOutputValuesToConfig writes the output values of the IAM creation to the constellation config file.
writeOutputValuesToConfig(conf *config.Config, flags iamFlags, iamFile cloudcmd.IAMOutput) writeOutputValuesToConfig(conf *config.Config, iamFile cloudcmd.IAMOutput)
// parseFlagsAndSetupConfig parses the provider-specific flags and fills the values into the IAM config (output values of the command). // getIAMConfigOptions sets up the IAM values required to create the IAM configuration.
parseFlagsAndSetupConfig(cmd *cobra.Command, flags iamFlags, iamConfig *cloudcmd.IAMConfigOptions) (iamFlags, error) getIAMConfigOptions() *cloudcmd.IAMConfigOptions
// parseAndWriteIDFile parses the GCP service account key and writes it to a keyfile. It is only implemented for GCP. // parseAndWriteIDFile parses the GCP service account key and writes it to a keyfile. It is only implemented for GCP.
parseAndWriteIDFile(iamFile cloudcmd.IAMOutput, fileHandler file.Handler) error parseAndWriteIDFile(iamFile cloudcmd.IAMOutput, fileHandler file.Handler) error
}
// awsIAMCreator implements the providerIAMCreator interface for AWS. validateConfigWithFlagCompatibility(config.Config) error
type awsIAMCreator struct{}
func (c *awsIAMCreator) parseFlagsAndSetupConfig(cmd *cobra.Command, flags iamFlags, iamConfig *cloudcmd.IAMConfigOptions) (iamFlags, error) {
prefix, err := cmd.Flags().GetString("prefix")
if err != nil {
return iamFlags{}, fmt.Errorf("parsing prefix string: %w", err)
}
if len(prefix) > 36 {
return iamFlags{}, fmt.Errorf("prefix must be 36 characters or less")
}
zone, err := cmd.Flags().GetString("zone")
if err != nil {
return iamFlags{}, fmt.Errorf("parsing zone string: %w", err)
}
if !config.ValidateAWSZone(zone) {
return iamFlags{}, fmt.Errorf("invalid AWS zone. To find a valid zone, please refer to our docs and https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/using-regions-availability-zones.html#concepts-availability-zones")
}
// Infer region from zone.
region := zone[:len(zone)-1]
if !config.ValidateAWSRegion(region) {
return iamFlags{}, fmt.Errorf("invalid AWS region: %s", region)
}
flags.aws = awsFlags{
prefix: prefix,
zone: zone,
region: region,
}
// Setup IAM config.
iamConfig.AWS = cloudcmd.AWSIAMConfig{
Region: flags.aws.region,
Prefix: flags.aws.prefix,
}
return flags, nil
}
func (c *awsIAMCreator) printConfirmValues(cmd *cobra.Command, flags iamFlags) {
cmd.Printf("Region:\t\t%s\n", flags.aws.region)
cmd.Printf("Name Prefix:\t%s\n\n", flags.aws.prefix)
}
func (c *awsIAMCreator) printOutputValues(cmd *cobra.Command, flags iamFlags, iamFile cloudcmd.IAMOutput) {
cmd.Printf("region:\t\t\t%s\n", flags.aws.region)
cmd.Printf("zone:\t\t\t%s\n", flags.aws.zone)
cmd.Printf("iamProfileControlPlane:\t%s\n", iamFile.AWSOutput.ControlPlaneInstanceProfile)
cmd.Printf("iamProfileWorkerNodes:\t%s\n\n", iamFile.AWSOutput.WorkerNodeInstanceProfile)
}
func (c *awsIAMCreator) writeOutputValuesToConfig(conf *config.Config, flags iamFlags, iamFile cloudcmd.IAMOutput) {
conf.Provider.AWS.Region = flags.aws.region
conf.Provider.AWS.Zone = flags.aws.zone
conf.Provider.AWS.IAMProfileControlPlane = iamFile.AWSOutput.ControlPlaneInstanceProfile
conf.Provider.AWS.IAMProfileWorkerNodes = iamFile.AWSOutput.WorkerNodeInstanceProfile
for groupName, group := range conf.NodeGroups {
group.Zone = flags.aws.zone
conf.NodeGroups[groupName] = group
}
}
func (c *awsIAMCreator) parseAndWriteIDFile(_ cloudcmd.IAMOutput, _ file.Handler) error {
return nil
}
// azureIAMCreator implements the providerIAMCreator interface for Azure.
type azureIAMCreator struct{}
func (c *azureIAMCreator) parseFlagsAndSetupConfig(cmd *cobra.Command, flags iamFlags, iamConfig *cloudcmd.IAMConfigOptions) (iamFlags, error) {
region, err := cmd.Flags().GetString("region")
if err != nil {
return iamFlags{}, fmt.Errorf("parsing region string: %w", err)
}
resourceGroup, err := cmd.Flags().GetString("resourceGroup")
if err != nil {
return iamFlags{}, fmt.Errorf("parsing resourceGroup string: %w", err)
}
servicePrincipal, err := cmd.Flags().GetString("servicePrincipal")
if err != nil {
return iamFlags{}, fmt.Errorf("parsing servicePrincipal string: %w", err)
}
flags.azure = azureFlags{
region: region,
resourceGroup: resourceGroup,
servicePrincipal: servicePrincipal,
}
// Setup IAM config.
iamConfig.Azure = cloudcmd.AzureIAMConfig{
Region: flags.azure.region,
ResourceGroup: flags.azure.resourceGroup,
ServicePrincipal: flags.azure.servicePrincipal,
}
return flags, nil
}
func (c *azureIAMCreator) printConfirmValues(cmd *cobra.Command, flags iamFlags) {
cmd.Printf("Region:\t\t\t%s\n", flags.azure.region)
cmd.Printf("Resource Group:\t\t%s\n", flags.azure.resourceGroup)
cmd.Printf("Service Principal:\t%s\n\n", flags.azure.servicePrincipal)
}
func (c *azureIAMCreator) printOutputValues(cmd *cobra.Command, flags iamFlags, iamFile cloudcmd.IAMOutput) {
cmd.Printf("subscription:\t\t%s\n", iamFile.AzureOutput.SubscriptionID)
cmd.Printf("tenant:\t\t\t%s\n", iamFile.AzureOutput.TenantID)
cmd.Printf("location:\t\t%s\n", flags.azure.region)
cmd.Printf("resourceGroup:\t\t%s\n", flags.azure.resourceGroup)
cmd.Printf("userAssignedIdentity:\t%s\n", iamFile.AzureOutput.UAMIID)
}
func (c *azureIAMCreator) writeOutputValuesToConfig(conf *config.Config, flags iamFlags, iamFile cloudcmd.IAMOutput) {
conf.Provider.Azure.SubscriptionID = iamFile.AzureOutput.SubscriptionID
conf.Provider.Azure.TenantID = iamFile.AzureOutput.TenantID
conf.Provider.Azure.Location = flags.azure.region
conf.Provider.Azure.ResourceGroup = flags.azure.resourceGroup
conf.Provider.Azure.UserAssignedIdentity = iamFile.AzureOutput.UAMIID
}
func (c *azureIAMCreator) parseAndWriteIDFile(_ cloudcmd.IAMOutput, _ file.Handler) error {
return nil
}
// gcpIAMCreator implements the providerIAMCreator interface for GCP.
type gcpIAMCreator struct {
pf pathprefix.PathPrefixer
}
func (c *gcpIAMCreator) parseFlagsAndSetupConfig(cmd *cobra.Command, flags iamFlags, iamConfig *cloudcmd.IAMConfigOptions) (iamFlags, error) {
zone, err := cmd.Flags().GetString("zone")
if err != nil {
return iamFlags{}, fmt.Errorf("parsing zone string: %w", err)
}
if !zoneRegex.MatchString(zone) {
return iamFlags{}, fmt.Errorf("invalid zone string: %s", zone)
}
// Infer region from zone.
zoneParts := strings.Split(zone, "-")
region := fmt.Sprintf("%s-%s", zoneParts[0], zoneParts[1])
if !regionRegex.MatchString(region) {
return iamFlags{}, fmt.Errorf("invalid region string: %s", region)
}
projectID, err := cmd.Flags().GetString("projectID")
if err != nil {
return iamFlags{}, fmt.Errorf("parsing projectID string: %w", err)
}
if !gcpIDRegex.MatchString(projectID) {
return iamFlags{}, fmt.Errorf("projectID %q doesn't match %s", projectID, gcpIDRegex)
}
serviceAccID, err := cmd.Flags().GetString("serviceAccountID")
if err != nil {
return iamFlags{}, fmt.Errorf("parsing serviceAccountID string: %w", err)
}
if !gcpIDRegex.MatchString(serviceAccID) {
return iamFlags{}, fmt.Errorf("serviceAccountID %q doesn't match %s", serviceAccID, gcpIDRegex)
}
flags.gcp = gcpFlags{
zone: zone,
region: region,
projectID: projectID,
serviceAccountID: serviceAccID,
}
// Setup IAM config.
iamConfig.GCP = cloudcmd.GCPIAMConfig{
Zone: flags.gcp.zone,
Region: flags.gcp.region,
ProjectID: flags.gcp.projectID,
ServiceAccountID: flags.gcp.serviceAccountID,
}
return flags, nil
}
func (c *gcpIAMCreator) printConfirmValues(cmd *cobra.Command, flags iamFlags) {
cmd.Printf("Project ID:\t\t%s\n", flags.gcp.projectID)
cmd.Printf("Service Account ID:\t%s\n", flags.gcp.serviceAccountID)
cmd.Printf("Region:\t\t\t%s\n", flags.gcp.region)
cmd.Printf("Zone:\t\t\t%s\n\n", flags.gcp.zone)
}
func (c *gcpIAMCreator) printOutputValues(cmd *cobra.Command, flags iamFlags, _ cloudcmd.IAMOutput) {
cmd.Printf("projectID:\t\t%s\n", flags.gcp.projectID)
cmd.Printf("region:\t\t\t%s\n", flags.gcp.region)
cmd.Printf("zone:\t\t\t%s\n", flags.gcp.zone)
cmd.Printf("serviceAccountKeyPath:\t%s\n\n", c.pf.PrefixPrintablePath(constants.GCPServiceAccountKeyFilename))
}
func (c *gcpIAMCreator) writeOutputValuesToConfig(conf *config.Config, flags iamFlags, _ cloudcmd.IAMOutput) {
conf.Provider.GCP.Project = flags.gcp.projectID
conf.Provider.GCP.ServiceAccountKeyPath = constants.GCPServiceAccountKeyFilename // File was created in workspace, so only the filename is needed.
conf.Provider.GCP.Region = flags.gcp.region
conf.Provider.GCP.Zone = flags.gcp.zone
for groupName, group := range conf.NodeGroups {
group.Zone = flags.gcp.zone
conf.NodeGroups[groupName] = group
}
}
func (c *gcpIAMCreator) parseAndWriteIDFile(iamFile cloudcmd.IAMOutput, fileHandler file.Handler) error {
// GCP needs to write the service account key to a file.
tmpOut, err := parseIDFile(iamFile.GCPOutput.ServiceAccountKey)
if err != nil {
return err
}
return fileHandler.WriteJSON(constants.GCPServiceAccountKeyFilename, tmpOut, file.OptNone)
} }
// parseIDFile parses the given base64 encoded JSON string of the GCP service account key and returns a map. // parseIDFile parses the given base64 encoded JSON string of the GCP service account key and returns a map.
@ -594,30 +219,17 @@ func parseIDFile(serviceAccountKeyBase64 string) (map[string]string, error) {
} }
// validateConfigWithFlagCompatibility checks if the config is compatible with the flags. // validateConfigWithFlagCompatibility checks if the config is compatible with the flags.
func validateConfigWithFlagCompatibility(iamProvider cloudprovider.Provider, cfg config.Config, flags iamFlags) error { func validateConfigWithFlagCompatibility(iamProvider cloudprovider.Provider, cfg config.Config, zone string) error {
if !cfg.HasProvider(iamProvider) { if !cfg.HasProvider(iamProvider) {
return fmt.Errorf("cloud provider from the the configuration file differs from the one provided via the command %q", iamProvider) return fmt.Errorf("cloud provider from the the configuration file differs from the one provided via the command %q", iamProvider)
} }
return checkIfCfgZoneAndFlagZoneDiffer(iamProvider, flags, cfg) return checkIfCfgZoneAndFlagZoneDiffer(zone, cfg)
} }
func checkIfCfgZoneAndFlagZoneDiffer(iamProvider cloudprovider.Provider, flags iamFlags, cfg config.Config) error { func checkIfCfgZoneAndFlagZoneDiffer(zone string, cfg config.Config) error {
flagZone := flagZoneOrAzRegion(iamProvider, flags)
configZone := cfg.GetZone() configZone := cfg.GetZone()
if configZone != "" && flagZone != configZone { if configZone != "" && zone != configZone {
return fmt.Errorf("zone/region from the configuration file %q differs from the one provided via flags %q", configZone, flagZone) return fmt.Errorf("zone/region from the configuration file %q differs from the one provided via flags %q", configZone, zone)
} }
return nil return nil
} }
func flagZoneOrAzRegion(provider cloudprovider.Provider, flags iamFlags) string {
switch provider {
case cloudprovider.AWS:
return flags.aws.zone
case cloudprovider.Azure:
return flags.azure.region
case cloudprovider.GCP:
return flags.gcp.zone
}
return ""
}

View file

@ -82,7 +82,6 @@ func TestIAMCreateAWS(t *testing.T) {
testCases := map[string]struct { testCases := map[string]struct {
setupFs func(require *require.Assertions, provider cloudprovider.Provider, existingConfigFiles []string, existingDirs []string) afero.Fs setupFs func(require *require.Assertions, provider cloudprovider.Provider, existingConfigFiles []string, existingDirs []string) afero.Fs
creator *stubIAMCreator creator *stubIAMCreator
provider cloudprovider.Provider
zoneFlag string zoneFlag string
prefixFlag string prefixFlag string
yesFlag bool yesFlag bool
@ -96,26 +95,14 @@ func TestIAMCreateAWS(t *testing.T) {
"iam create aws": { "iam create aws": {
setupFs: defaultFs, setupFs: defaultFs,
creator: &stubIAMCreator{id: validIAMIDFile}, creator: &stubIAMCreator{id: validIAMIDFile},
provider: cloudprovider.AWS,
zoneFlag: "us-east-2a", zoneFlag: "us-east-2a",
prefixFlag: "test", prefixFlag: "test",
yesFlag: true, yesFlag: true,
existingConfigFiles: []string{constants.ConfigFilename}, existingConfigFiles: []string{constants.ConfigFilename},
}, },
"iam create aws fails when --zone has no availability zone": {
setupFs: defaultFs,
creator: &stubIAMCreator{id: validIAMIDFile},
provider: cloudprovider.AWS,
zoneFlag: "us-east-1",
prefixFlag: "test",
yesFlag: true,
existingConfigFiles: []string{constants.ConfigFilename},
wantErr: true,
},
"iam create aws --update-config": { "iam create aws --update-config": {
setupFs: defaultFs, setupFs: defaultFs,
creator: &stubIAMCreator{id: validIAMIDFile}, creator: &stubIAMCreator{id: validIAMIDFile},
provider: cloudprovider.AWS,
zoneFlag: "us-east-2a", zoneFlag: "us-east-2a",
prefixFlag: "test", prefixFlag: "test",
yesFlag: true, yesFlag: true,
@ -130,7 +117,6 @@ func TestIAMCreateAWS(t *testing.T) {
return *cfg return *cfg
}()), }()),
creator: &stubIAMCreator{id: validIAMIDFile}, creator: &stubIAMCreator{id: validIAMIDFile},
provider: cloudprovider.AWS,
zoneFlag: "us-east-1a", zoneFlag: "us-east-1a",
prefixFlag: "test", prefixFlag: "test",
yesFlag: true, yesFlag: true,
@ -141,7 +127,6 @@ func TestIAMCreateAWS(t *testing.T) {
"iam create aws --update-config fails when config has different provider": { "iam create aws --update-config fails when config has different provider": {
setupFs: createFSWithConfig(*createConfig(cloudprovider.GCP)), setupFs: createFSWithConfig(*createConfig(cloudprovider.GCP)),
creator: &stubIAMCreator{id: validIAMIDFile}, creator: &stubIAMCreator{id: validIAMIDFile},
provider: cloudprovider.AWS,
zoneFlag: "us-east-1a", zoneFlag: "us-east-1a",
prefixFlag: "test", prefixFlag: "test",
yesFlag: true, yesFlag: true,
@ -152,7 +137,6 @@ func TestIAMCreateAWS(t *testing.T) {
"iam create aws no config": { "iam create aws no config": {
setupFs: defaultFs, setupFs: defaultFs,
creator: &stubIAMCreator{id: validIAMIDFile}, creator: &stubIAMCreator{id: validIAMIDFile},
provider: cloudprovider.AWS,
zoneFlag: "us-east-2a", zoneFlag: "us-east-2a",
prefixFlag: "test", prefixFlag: "test",
yesFlag: true, yesFlag: true,
@ -160,7 +144,6 @@ func TestIAMCreateAWS(t *testing.T) {
"iam create aws existing terraform dir": { "iam create aws existing terraform dir": {
setupFs: defaultFs, setupFs: defaultFs,
creator: &stubIAMCreator{id: validIAMIDFile}, creator: &stubIAMCreator{id: validIAMIDFile},
provider: cloudprovider.AWS,
zoneFlag: "us-east-2a", zoneFlag: "us-east-2a",
prefixFlag: "test", prefixFlag: "test",
yesFlag: true, yesFlag: true,
@ -170,7 +153,6 @@ func TestIAMCreateAWS(t *testing.T) {
"interactive": { "interactive": {
setupFs: defaultFs, setupFs: defaultFs,
creator: &stubIAMCreator{id: validIAMIDFile}, creator: &stubIAMCreator{id: validIAMIDFile},
provider: cloudprovider.AWS,
zoneFlag: "us-east-2a", zoneFlag: "us-east-2a",
prefixFlag: "test", prefixFlag: "test",
stdin: "yes\n", stdin: "yes\n",
@ -178,7 +160,6 @@ func TestIAMCreateAWS(t *testing.T) {
"interactive update config": { "interactive update config": {
setupFs: defaultFs, setupFs: defaultFs,
creator: &stubIAMCreator{id: validIAMIDFile}, creator: &stubIAMCreator{id: validIAMIDFile},
provider: cloudprovider.AWS,
zoneFlag: "us-east-2a", zoneFlag: "us-east-2a",
prefixFlag: "test", prefixFlag: "test",
stdin: "yes\n", stdin: "yes\n",
@ -188,7 +169,6 @@ func TestIAMCreateAWS(t *testing.T) {
"interactive abort": { "interactive abort": {
setupFs: defaultFs, setupFs: defaultFs,
creator: &stubIAMCreator{id: validIAMIDFile}, creator: &stubIAMCreator{id: validIAMIDFile},
provider: cloudprovider.AWS,
zoneFlag: "us-east-2a", zoneFlag: "us-east-2a",
prefixFlag: "test", prefixFlag: "test",
stdin: "no\n", stdin: "no\n",
@ -197,7 +177,6 @@ func TestIAMCreateAWS(t *testing.T) {
"interactive update config abort": { "interactive update config abort": {
setupFs: defaultFs, setupFs: defaultFs,
creator: &stubIAMCreator{id: validIAMIDFile}, creator: &stubIAMCreator{id: validIAMIDFile},
provider: cloudprovider.AWS,
zoneFlag: "us-east-2a", zoneFlag: "us-east-2a",
prefixFlag: "test", prefixFlag: "test",
stdin: "no\n", stdin: "no\n",
@ -205,19 +184,9 @@ func TestIAMCreateAWS(t *testing.T) {
wantAbort: true, wantAbort: true,
existingConfigFiles: []string{constants.ConfigFilename}, existingConfigFiles: []string{constants.ConfigFilename},
}, },
"invalid zone": {
setupFs: defaultFs,
creator: &stubIAMCreator{id: validIAMIDFile},
provider: cloudprovider.AWS,
zoneFlag: "us-west",
prefixFlag: "test",
yesFlag: true,
wantErr: true,
},
"unwritable fs": { "unwritable fs": {
setupFs: readOnlyFs, setupFs: readOnlyFs,
creator: &stubIAMCreator{id: validIAMIDFile}, creator: &stubIAMCreator{id: validIAMIDFile},
provider: cloudprovider.AWS,
zoneFlag: "us-east-2a", zoneFlag: "us-east-2a",
prefixFlag: "test", prefixFlag: "test",
yesFlag: true, yesFlag: true,
@ -236,27 +205,7 @@ func TestIAMCreateAWS(t *testing.T) {
cmd.SetErr(&bytes.Buffer{}) cmd.SetErr(&bytes.Buffer{})
cmd.SetIn(bytes.NewBufferString(tc.stdin)) cmd.SetIn(bytes.NewBufferString(tc.stdin))
// register persistent flags manually fileHandler := file.NewHandler(tc.setupFs(require, cloudprovider.AWS, tc.existingConfigFiles, tc.existingDirs))
cmd.Flags().String("workspace", "", "")
cmd.Flags().Bool("update-config", false, "")
cmd.Flags().Bool("yes", false, "")
cmd.Flags().String("name", "constell", "")
cmd.Flags().String("tf-log", "NONE", "")
if tc.zoneFlag != "" {
require.NoError(cmd.Flags().Set("zone", tc.zoneFlag))
}
if tc.prefixFlag != "" {
require.NoError(cmd.Flags().Set("prefix", tc.prefixFlag))
}
if tc.yesFlag {
require.NoError(cmd.Flags().Set("yes", "true"))
}
if tc.updateConfigFlag {
require.NoError(cmd.Flags().Set("update-config", "true"))
}
fileHandler := file.NewHandler(tc.setupFs(require, tc.provider, tc.existingConfigFiles, tc.existingDirs))
iamCreator := &iamCreator{ iamCreator := &iamCreator{
cmd: cmd, cmd: cmd,
@ -265,8 +214,17 @@ func TestIAMCreateAWS(t *testing.T) {
creator: tc.creator, creator: tc.creator,
fileHandler: fileHandler, fileHandler: fileHandler,
iamConfig: &cloudcmd.IAMConfigOptions{}, iamConfig: &cloudcmd.IAMConfigOptions{},
provider: tc.provider, provider: cloudprovider.AWS,
providerCreator: &awsIAMCreator{}, flags: iamCreateFlags{
yes: tc.yesFlag,
updateConfig: tc.updateConfigFlag,
},
providerCreator: &awsIAMCreator{
flags: awsIAMCreateFlags{
zone: tc.zoneFlag,
prefix: tc.prefixFlag,
},
},
} }
err := iamCreator.create(cmd.Context()) err := iamCreator.create(cmd.Context())
@ -315,7 +273,6 @@ func TestIAMCreateAzure(t *testing.T) {
testCases := map[string]struct { testCases := map[string]struct {
setupFs func(require *require.Assertions, provider cloudprovider.Provider, existingConfigFiles []string, existingDirs []string) afero.Fs setupFs func(require *require.Assertions, provider cloudprovider.Provider, existingConfigFiles []string, existingDirs []string) afero.Fs
creator *stubIAMCreator creator *stubIAMCreator
provider cloudprovider.Provider
regionFlag string regionFlag string
servicePrincipalFlag string servicePrincipalFlag string
resourceGroupFlag string resourceGroupFlag string
@ -330,7 +287,6 @@ func TestIAMCreateAzure(t *testing.T) {
"iam create azure": { "iam create azure": {
setupFs: defaultFs, setupFs: defaultFs,
creator: &stubIAMCreator{id: validIAMIDFile}, creator: &stubIAMCreator{id: validIAMIDFile},
provider: cloudprovider.Azure,
regionFlag: "westus", regionFlag: "westus",
servicePrincipalFlag: "constell-test-sp", servicePrincipalFlag: "constell-test-sp",
resourceGroupFlag: "constell-test-rg", resourceGroupFlag: "constell-test-rg",
@ -339,7 +295,6 @@ func TestIAMCreateAzure(t *testing.T) {
"iam create azure with existing config": { "iam create azure with existing config": {
setupFs: defaultFs, setupFs: defaultFs,
creator: &stubIAMCreator{id: validIAMIDFile}, creator: &stubIAMCreator{id: validIAMIDFile},
provider: cloudprovider.Azure,
regionFlag: "westus", regionFlag: "westus",
servicePrincipalFlag: "constell-test-sp", servicePrincipalFlag: "constell-test-sp",
resourceGroupFlag: "constell-test-rg", resourceGroupFlag: "constell-test-rg",
@ -349,7 +304,6 @@ func TestIAMCreateAzure(t *testing.T) {
"iam create azure --update-config": { "iam create azure --update-config": {
setupFs: defaultFs, setupFs: defaultFs,
creator: &stubIAMCreator{id: validIAMIDFile}, creator: &stubIAMCreator{id: validIAMIDFile},
provider: cloudprovider.Azure,
regionFlag: "westus", regionFlag: "westus",
servicePrincipalFlag: "constell-test-sp", servicePrincipalFlag: "constell-test-sp",
resourceGroupFlag: "constell-test-rg", resourceGroupFlag: "constell-test-rg",
@ -360,7 +314,6 @@ func TestIAMCreateAzure(t *testing.T) {
"iam create azure existing terraform dir": { "iam create azure existing terraform dir": {
setupFs: defaultFs, setupFs: defaultFs,
creator: &stubIAMCreator{id: validIAMIDFile}, creator: &stubIAMCreator{id: validIAMIDFile},
provider: cloudprovider.Azure,
regionFlag: "westus", regionFlag: "westus",
servicePrincipalFlag: "constell-test-sp", servicePrincipalFlag: "constell-test-sp",
resourceGroupFlag: "constell-test-rg", resourceGroupFlag: "constell-test-rg",
@ -371,7 +324,6 @@ func TestIAMCreateAzure(t *testing.T) {
"interactive": { "interactive": {
setupFs: defaultFs, setupFs: defaultFs,
creator: &stubIAMCreator{id: validIAMIDFile}, creator: &stubIAMCreator{id: validIAMIDFile},
provider: cloudprovider.Azure,
regionFlag: "westus", regionFlag: "westus",
servicePrincipalFlag: "constell-test-sp", servicePrincipalFlag: "constell-test-sp",
resourceGroupFlag: "constell-test-rg", resourceGroupFlag: "constell-test-rg",
@ -380,7 +332,6 @@ func TestIAMCreateAzure(t *testing.T) {
"interactive update config": { "interactive update config": {
setupFs: defaultFs, setupFs: defaultFs,
creator: &stubIAMCreator{id: validIAMIDFile}, creator: &stubIAMCreator{id: validIAMIDFile},
provider: cloudprovider.Azure,
regionFlag: "westus", regionFlag: "westus",
servicePrincipalFlag: "constell-test-sp", servicePrincipalFlag: "constell-test-sp",
resourceGroupFlag: "constell-test-rg", resourceGroupFlag: "constell-test-rg",
@ -391,7 +342,6 @@ func TestIAMCreateAzure(t *testing.T) {
"interactive abort": { "interactive abort": {
setupFs: defaultFs, setupFs: defaultFs,
creator: &stubIAMCreator{id: validIAMIDFile}, creator: &stubIAMCreator{id: validIAMIDFile},
provider: cloudprovider.Azure,
regionFlag: "westus", regionFlag: "westus",
servicePrincipalFlag: "constell-test-sp", servicePrincipalFlag: "constell-test-sp",
resourceGroupFlag: "constell-test-rg", resourceGroupFlag: "constell-test-rg",
@ -401,7 +351,6 @@ func TestIAMCreateAzure(t *testing.T) {
"interactive update config abort": { "interactive update config abort": {
setupFs: defaultFs, setupFs: defaultFs,
creator: &stubIAMCreator{id: validIAMIDFile}, creator: &stubIAMCreator{id: validIAMIDFile},
provider: cloudprovider.Azure,
regionFlag: "westus", regionFlag: "westus",
servicePrincipalFlag: "constell-test-sp", servicePrincipalFlag: "constell-test-sp",
resourceGroupFlag: "constell-test-rg", resourceGroupFlag: "constell-test-rg",
@ -413,7 +362,6 @@ func TestIAMCreateAzure(t *testing.T) {
"unwritable fs": { "unwritable fs": {
setupFs: readOnlyFs, setupFs: readOnlyFs,
creator: &stubIAMCreator{id: validIAMIDFile}, creator: &stubIAMCreator{id: validIAMIDFile},
provider: cloudprovider.Azure,
regionFlag: "westus", regionFlag: "westus",
servicePrincipalFlag: "constell-test-sp", servicePrincipalFlag: "constell-test-sp",
resourceGroupFlag: "constell-test-rg", resourceGroupFlag: "constell-test-rg",
@ -433,30 +381,7 @@ func TestIAMCreateAzure(t *testing.T) {
cmd.SetErr(&bytes.Buffer{}) cmd.SetErr(&bytes.Buffer{})
cmd.SetIn(bytes.NewBufferString(tc.stdin)) cmd.SetIn(bytes.NewBufferString(tc.stdin))
// register persistent flags manually fileHandler := file.NewHandler(tc.setupFs(require, cloudprovider.Azure, tc.existingConfigFiles, tc.existingDirs))
cmd.Flags().String("workspace", "", "")
cmd.Flags().Bool("update-config", false, "")
cmd.Flags().Bool("yes", false, "")
cmd.Flags().String("name", "constell", "")
cmd.Flags().String("tf-log", "NONE", "")
if tc.regionFlag != "" {
require.NoError(cmd.Flags().Set("region", tc.regionFlag))
}
if tc.resourceGroupFlag != "" {
require.NoError(cmd.Flags().Set("resourceGroup", tc.resourceGroupFlag))
}
if tc.servicePrincipalFlag != "" {
require.NoError(cmd.Flags().Set("servicePrincipal", tc.servicePrincipalFlag))
}
if tc.yesFlag {
require.NoError(cmd.Flags().Set("yes", "true"))
}
if tc.updateConfigFlag {
require.NoError(cmd.Flags().Set("update-config", "true"))
}
fileHandler := file.NewHandler(tc.setupFs(require, tc.provider, tc.existingConfigFiles, tc.existingDirs))
iamCreator := &iamCreator{ iamCreator := &iamCreator{
cmd: cmd, cmd: cmd,
@ -465,8 +390,18 @@ func TestIAMCreateAzure(t *testing.T) {
creator: tc.creator, creator: tc.creator,
fileHandler: fileHandler, fileHandler: fileHandler,
iamConfig: &cloudcmd.IAMConfigOptions{}, iamConfig: &cloudcmd.IAMConfigOptions{},
provider: tc.provider, provider: cloudprovider.Azure,
providerCreator: &azureIAMCreator{}, flags: iamCreateFlags{
yes: tc.yesFlag,
updateConfig: tc.updateConfigFlag,
},
providerCreator: &azureIAMCreator{
flags: azureIAMCreateFlags{
region: tc.regionFlag,
resourceGroup: tc.resourceGroupFlag,
servicePrincipal: tc.servicePrincipalFlag,
},
},
} }
err := iamCreator.create(cmd.Context()) err := iamCreator.create(cmd.Context())
@ -519,7 +454,6 @@ func TestIAMCreateGCP(t *testing.T) {
testCases := map[string]struct { testCases := map[string]struct {
setupFs func(require *require.Assertions, provider cloudprovider.Provider, existingConfigFiles []string, existingDirs []string) afero.Fs setupFs func(require *require.Assertions, provider cloudprovider.Provider, existingConfigFiles []string, existingDirs []string) afero.Fs
creator *stubIAMCreator creator *stubIAMCreator
provider cloudprovider.Provider
zoneFlag string zoneFlag string
serviceAccountIDFlag string serviceAccountIDFlag string
projectIDFlag string projectIDFlag string
@ -534,7 +468,6 @@ func TestIAMCreateGCP(t *testing.T) {
"iam create gcp": { "iam create gcp": {
setupFs: defaultFs, setupFs: defaultFs,
creator: &stubIAMCreator{id: validIAMIDFile}, creator: &stubIAMCreator{id: validIAMIDFile},
provider: cloudprovider.GCP,
zoneFlag: "europe-west1-a", zoneFlag: "europe-west1-a",
serviceAccountIDFlag: "constell-test", serviceAccountIDFlag: "constell-test",
projectIDFlag: "constell-1234", projectIDFlag: "constell-1234",
@ -543,7 +476,6 @@ func TestIAMCreateGCP(t *testing.T) {
"iam create gcp with existing config": { "iam create gcp with existing config": {
setupFs: defaultFs, setupFs: defaultFs,
creator: &stubIAMCreator{id: validIAMIDFile}, creator: &stubIAMCreator{id: validIAMIDFile},
provider: cloudprovider.GCP,
zoneFlag: "europe-west1-a", zoneFlag: "europe-west1-a",
serviceAccountIDFlag: "constell-test", serviceAccountIDFlag: "constell-test",
projectIDFlag: "constell-1234", projectIDFlag: "constell-1234",
@ -553,7 +485,6 @@ func TestIAMCreateGCP(t *testing.T) {
"iam create gcp --update-config": { "iam create gcp --update-config": {
setupFs: defaultFs, setupFs: defaultFs,
creator: &stubIAMCreator{id: validIAMIDFile}, creator: &stubIAMCreator{id: validIAMIDFile},
provider: cloudprovider.GCP,
zoneFlag: "europe-west1-a", zoneFlag: "europe-west1-a",
serviceAccountIDFlag: "constell-test", serviceAccountIDFlag: "constell-test",
projectIDFlag: "constell-1234", projectIDFlag: "constell-1234",
@ -564,7 +495,6 @@ func TestIAMCreateGCP(t *testing.T) {
"iam create gcp existing terraform dir": { "iam create gcp existing terraform dir": {
setupFs: defaultFs, setupFs: defaultFs,
creator: &stubIAMCreator{id: validIAMIDFile}, creator: &stubIAMCreator{id: validIAMIDFile},
provider: cloudprovider.GCP,
zoneFlag: "europe-west1-a", zoneFlag: "europe-west1-a",
serviceAccountIDFlag: "constell-test", serviceAccountIDFlag: "constell-test",
projectIDFlag: "constell-1234", projectIDFlag: "constell-1234",
@ -573,18 +503,9 @@ func TestIAMCreateGCP(t *testing.T) {
yesFlag: true, yesFlag: true,
wantErr: true, wantErr: true,
}, },
"iam create gcp invalid flags": {
setupFs: defaultFs,
creator: &stubIAMCreator{id: validIAMIDFile},
provider: cloudprovider.GCP,
zoneFlag: "-a",
yesFlag: true,
wantErr: true,
},
"iam create gcp invalid b64": { "iam create gcp invalid b64": {
setupFs: defaultFs, setupFs: defaultFs,
creator: &stubIAMCreator{id: invalidIAMIDFile}, creator: &stubIAMCreator{id: invalidIAMIDFile},
provider: cloudprovider.GCP,
zoneFlag: "europe-west1-a", zoneFlag: "europe-west1-a",
serviceAccountIDFlag: "constell-test", serviceAccountIDFlag: "constell-test",
projectIDFlag: "constell-1234", projectIDFlag: "constell-1234",
@ -594,7 +515,6 @@ func TestIAMCreateGCP(t *testing.T) {
"interactive": { "interactive": {
setupFs: defaultFs, setupFs: defaultFs,
creator: &stubIAMCreator{id: validIAMIDFile}, creator: &stubIAMCreator{id: validIAMIDFile},
provider: cloudprovider.GCP,
zoneFlag: "europe-west1-a", zoneFlag: "europe-west1-a",
serviceAccountIDFlag: "constell-test", serviceAccountIDFlag: "constell-test",
projectIDFlag: "constell-1234", projectIDFlag: "constell-1234",
@ -603,7 +523,6 @@ func TestIAMCreateGCP(t *testing.T) {
"interactive update config": { "interactive update config": {
setupFs: defaultFs, setupFs: defaultFs,
creator: &stubIAMCreator{id: validIAMIDFile}, creator: &stubIAMCreator{id: validIAMIDFile},
provider: cloudprovider.GCP,
zoneFlag: "europe-west1-a", zoneFlag: "europe-west1-a",
serviceAccountIDFlag: "constell-test", serviceAccountIDFlag: "constell-test",
projectIDFlag: "constell-1234", projectIDFlag: "constell-1234",
@ -614,7 +533,6 @@ func TestIAMCreateGCP(t *testing.T) {
"interactive abort": { "interactive abort": {
setupFs: defaultFs, setupFs: defaultFs,
creator: &stubIAMCreator{id: validIAMIDFile}, creator: &stubIAMCreator{id: validIAMIDFile},
provider: cloudprovider.GCP,
zoneFlag: "europe-west1-a", zoneFlag: "europe-west1-a",
serviceAccountIDFlag: "constell-test", serviceAccountIDFlag: "constell-test",
projectIDFlag: "constell-1234", projectIDFlag: "constell-1234",
@ -624,7 +542,6 @@ func TestIAMCreateGCP(t *testing.T) {
"interactive abort update config": { "interactive abort update config": {
setupFs: defaultFs, setupFs: defaultFs,
creator: &stubIAMCreator{id: validIAMIDFile}, creator: &stubIAMCreator{id: validIAMIDFile},
provider: cloudprovider.GCP,
zoneFlag: "europe-west1-a", zoneFlag: "europe-west1-a",
serviceAccountIDFlag: "constell-test", serviceAccountIDFlag: "constell-test",
projectIDFlag: "constell-1234", projectIDFlag: "constell-1234",
@ -636,7 +553,6 @@ func TestIAMCreateGCP(t *testing.T) {
"unwritable fs": { "unwritable fs": {
setupFs: readOnlyFs, setupFs: readOnlyFs,
creator: &stubIAMCreator{id: validIAMIDFile}, creator: &stubIAMCreator{id: validIAMIDFile},
provider: cloudprovider.GCP,
zoneFlag: "europe-west1-a", zoneFlag: "europe-west1-a",
serviceAccountIDFlag: "constell-test", serviceAccountIDFlag: "constell-test",
projectIDFlag: "constell-1234", projectIDFlag: "constell-1234",
@ -656,30 +572,7 @@ func TestIAMCreateGCP(t *testing.T) {
cmd.SetErr(&bytes.Buffer{}) cmd.SetErr(&bytes.Buffer{})
cmd.SetIn(bytes.NewBufferString(tc.stdin)) cmd.SetIn(bytes.NewBufferString(tc.stdin))
// register persistent flags manually fileHandler := file.NewHandler(tc.setupFs(require, cloudprovider.GCP, tc.existingConfigFiles, tc.existingDirs))
cmd.Flags().String("workspace", "", "")
cmd.Flags().Bool("update-config", false, "")
cmd.Flags().Bool("yes", false, "")
cmd.Flags().String("name", "constell", "")
cmd.Flags().String("tf-log", "NONE", "")
if tc.zoneFlag != "" {
require.NoError(cmd.Flags().Set("zone", tc.zoneFlag))
}
if tc.serviceAccountIDFlag != "" {
require.NoError(cmd.Flags().Set("serviceAccountID", tc.serviceAccountIDFlag))
}
if tc.projectIDFlag != "" {
require.NoError(cmd.Flags().Set("projectID", tc.projectIDFlag))
}
if tc.yesFlag {
require.NoError(cmd.Flags().Set("yes", "true"))
}
if tc.updateConfigFlag {
require.NoError(cmd.Flags().Set("update-config", "true"))
}
fileHandler := file.NewHandler(tc.setupFs(require, tc.provider, tc.existingConfigFiles, tc.existingDirs))
iamCreator := &iamCreator{ iamCreator := &iamCreator{
cmd: cmd, cmd: cmd,
@ -688,8 +581,18 @@ func TestIAMCreateGCP(t *testing.T) {
creator: tc.creator, creator: tc.creator,
fileHandler: fileHandler, fileHandler: fileHandler,
iamConfig: &cloudcmd.IAMConfigOptions{}, iamConfig: &cloudcmd.IAMConfigOptions{},
provider: tc.provider, provider: cloudprovider.GCP,
providerCreator: &gcpIAMCreator{}, flags: iamCreateFlags{
yes: tc.yesFlag,
updateConfig: tc.updateConfigFlag,
},
providerCreator: &gcpIAMCreator{
flags: gcpIAMCreateFlags{
zone: tc.zoneFlag,
serviceAccountID: tc.serviceAccountIDFlag,
projectID: tc.projectIDFlag,
},
},
} }
err := iamCreator.create(cmd.Context()) err := iamCreator.create(cmd.Context())
@ -724,7 +627,7 @@ func TestValidateConfigWithFlagCompatibility(t *testing.T) {
testCases := map[string]struct { testCases := map[string]struct {
iamProvider cloudprovider.Provider iamProvider cloudprovider.Provider
cfg config.Config cfg config.Config
flags iamFlags zone string
wantErr bool wantErr bool
}{ }{
"AWS valid when cfg.zone == flag.zone": { "AWS valid when cfg.zone == flag.zone": {
@ -734,21 +637,13 @@ func TestValidateConfigWithFlagCompatibility(t *testing.T) {
cfg.Provider.AWS.Zone = "europe-west-1a" cfg.Provider.AWS.Zone = "europe-west-1a"
return *cfg return *cfg
}(), }(),
flags: iamFlags{
aws: awsFlags{
zone: "europe-west-1a", zone: "europe-west-1a",
}, },
},
},
"AWS valid when cfg.zone not set": { "AWS valid when cfg.zone not set": {
iamProvider: cloudprovider.AWS, iamProvider: cloudprovider.AWS,
cfg: *createConfig(cloudprovider.AWS), cfg: *createConfig(cloudprovider.AWS),
flags: iamFlags{
aws: awsFlags{
zone: "europe-west-1a", zone: "europe-west-1a",
}, },
},
},
"GCP invalid when cfg.zone != flag.zone": { "GCP invalid when cfg.zone != flag.zone": {
iamProvider: cloudprovider.GCP, iamProvider: cloudprovider.GCP,
cfg: func() config.Config { cfg: func() config.Config {
@ -756,11 +651,7 @@ func TestValidateConfigWithFlagCompatibility(t *testing.T) {
cfg.Provider.GCP.Zone = "europe-west-1a" cfg.Provider.GCP.Zone = "europe-west-1a"
return *cfg return *cfg
}(), }(),
flags: iamFlags{
aws: awsFlags{
zone: "us-west-1a", zone: "us-west-1a",
},
},
wantErr: true, wantErr: true,
}, },
"Azure invalid when cfg.zone != flag.zone": { "Azure invalid when cfg.zone != flag.zone": {
@ -770,11 +661,7 @@ func TestValidateConfigWithFlagCompatibility(t *testing.T) {
cfg.Provider.Azure.Location = "europe-west-1a" cfg.Provider.Azure.Location = "europe-west-1a"
return *cfg return *cfg
}(), }(),
flags: iamFlags{
aws: awsFlags{
zone: "us-west-1a", zone: "us-west-1a",
},
},
wantErr: true, wantErr: true,
}, },
"GCP invalid when cfg.provider different from iam provider": { "GCP invalid when cfg.provider different from iam provider": {
@ -786,7 +673,7 @@ func TestValidateConfigWithFlagCompatibility(t *testing.T) {
for name, tc := range testCases { for name, tc := range testCases {
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
assert := assert.New(t) assert := assert.New(t)
err := validateConfigWithFlagCompatibility(tc.iamProvider, tc.cfg, tc.flags) err := validateConfigWithFlagCompatibility(tc.iamProvider, tc.cfg, tc.zone)
if tc.wantErr { if tc.wantErr {
assert.Error(err) assert.Error(err)
return return

View file

@ -0,0 +1,121 @@
/*
Copyright (c) Edgeless Systems GmbH
SPDX-License-Identifier: AGPL-3.0-only
*/
package cmd
import (
"errors"
"fmt"
"github.com/edgelesssys/constellation/v2/cli/internal/cloudcmd"
"github.com/edgelesssys/constellation/v2/internal/cloud/cloudprovider"
"github.com/edgelesssys/constellation/v2/internal/config"
"github.com/edgelesssys/constellation/v2/internal/file"
"github.com/spf13/cobra"
"github.com/spf13/pflag"
)
// newIAMCreateAWSCmd returns a new cobra.Command for the iam create aws command.
func newIAMCreateAWSCmd() *cobra.Command {
cmd := &cobra.Command{
Use: "aws",
Short: "Create IAM configuration on AWS for your Constellation cluster",
Long: "Create IAM configuration on AWS for your Constellation cluster.",
Args: cobra.ExactArgs(0),
RunE: runIAMCreateAWS,
}
cmd.Flags().String("prefix", "", "name prefix for all resources (required)")
must(cobra.MarkFlagRequired(cmd.Flags(), "prefix"))
cmd.Flags().String("zone", "", "AWS availability zone the resources will be created in, e.g., us-east-2a (required)\n"+
"See the Constellation docs for a list of currently supported regions.")
must(cobra.MarkFlagRequired(cmd.Flags(), "zone"))
return cmd
}
func runIAMCreateAWS(cmd *cobra.Command, _ []string) error {
creator := &awsIAMCreator{}
if err := creator.flags.parse(cmd.Flags()); err != nil {
return err
}
return runIAMCreate(cmd, creator, cloudprovider.AWS)
}
// awsIAMCreateFlags contains the parsed flags of the iam create aws command.
type awsIAMCreateFlags struct {
prefix string
region string
zone string
}
func (f *awsIAMCreateFlags) parse(flags *pflag.FlagSet) error {
var err error
f.prefix, err = flags.GetString("prefix")
if err != nil {
return fmt.Errorf("getting 'prefix' flag: %w", err)
}
if len(f.prefix) > 36 {
return errors.New("prefix must be 36 characters or less")
}
f.zone, err = flags.GetString("zone")
if err != nil {
return fmt.Errorf("getting 'zone' flag: %w", err)
}
if !config.ValidateAWSZone(f.zone) {
return errors.New("invalid AWS zone. To find a valid zone, please refer to our docs and https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/using-regions-availability-zones.html#concepts-availability-zones")
}
// Infer region from zone.
f.region = f.zone[:len(f.zone)-1]
if !config.ValidateAWSRegion(f.region) {
return fmt.Errorf("invalid AWS region: %s", f.region)
}
return nil
}
// awsIAMCreator implements the providerIAMCreator interface for AWS.
type awsIAMCreator struct {
flags awsIAMCreateFlags
}
func (c *awsIAMCreator) getIAMConfigOptions() *cloudcmd.IAMConfigOptions {
return &cloudcmd.IAMConfigOptions{
AWS: cloudcmd.AWSIAMConfig{
Region: c.flags.region,
Prefix: c.flags.prefix,
},
}
}
func (c *awsIAMCreator) printConfirmValues(cmd *cobra.Command) {
cmd.Printf("Region:\t\t%s\n", c.flags.region)
cmd.Printf("Name Prefix:\t%s\n\n", c.flags.prefix)
}
func (c *awsIAMCreator) printOutputValues(cmd *cobra.Command, iamFile cloudcmd.IAMOutput) {
cmd.Printf("region:\t\t\t%s\n", c.flags.region)
cmd.Printf("zone:\t\t\t%s\n", c.flags.zone)
cmd.Printf("iamProfileControlPlane:\t%s\n", iamFile.AWSOutput.ControlPlaneInstanceProfile)
cmd.Printf("iamProfileWorkerNodes:\t%s\n\n", iamFile.AWSOutput.WorkerNodeInstanceProfile)
}
func (c *awsIAMCreator) writeOutputValuesToConfig(conf *config.Config, iamFile cloudcmd.IAMOutput) {
conf.Provider.AWS.Region = c.flags.region
conf.Provider.AWS.Zone = c.flags.zone
conf.Provider.AWS.IAMProfileControlPlane = iamFile.AWSOutput.ControlPlaneInstanceProfile
conf.Provider.AWS.IAMProfileWorkerNodes = iamFile.AWSOutput.WorkerNodeInstanceProfile
for groupName, group := range conf.NodeGroups {
group.Zone = c.flags.zone
conf.NodeGroups[groupName] = group
}
}
func (c *awsIAMCreator) parseAndWriteIDFile(_ cloudcmd.IAMOutput, _ file.Handler) error {
return nil
}
func (c *awsIAMCreator) validateConfigWithFlagCompatibility(conf config.Config) error {
return validateConfigWithFlagCompatibility(cloudprovider.AWS, conf, c.flags.zone)
}

View file

@ -0,0 +1,113 @@
/*
Copyright (c) Edgeless Systems GmbH
SPDX-License-Identifier: AGPL-3.0-only
*/
package cmd
import (
"fmt"
"github.com/edgelesssys/constellation/v2/cli/internal/cloudcmd"
"github.com/edgelesssys/constellation/v2/internal/cloud/cloudprovider"
"github.com/edgelesssys/constellation/v2/internal/config"
"github.com/edgelesssys/constellation/v2/internal/file"
"github.com/spf13/cobra"
"github.com/spf13/pflag"
)
// newIAMCreateAzureCmd returns a new cobra.Command for the iam create azure command.
func newIAMCreateAzureCmd() *cobra.Command {
cmd := &cobra.Command{
Use: "azure",
Short: "Create IAM configuration on Microsoft Azure for your Constellation cluster",
Long: "Create IAM configuration on Microsoft Azure for your Constellation cluster.",
Args: cobra.ExactArgs(0),
RunE: runIAMCreateAzure,
}
cmd.Flags().String("resourceGroup", "", "name prefix of the two resource groups your cluster / IAM resources will be created in (required)")
must(cobra.MarkFlagRequired(cmd.Flags(), "resourceGroup"))
cmd.Flags().String("region", "", "region the resources will be created in, e.g., westus (required)")
must(cobra.MarkFlagRequired(cmd.Flags(), "region"))
cmd.Flags().String("servicePrincipal", "", "name of the service principal that will be created (required)")
must(cobra.MarkFlagRequired(cmd.Flags(), "servicePrincipal"))
return cmd
}
func runIAMCreateAzure(cmd *cobra.Command, _ []string) error {
creator := &azureIAMCreator{}
if err := creator.flags.parse(cmd.Flags()); err != nil {
return err
}
return runIAMCreate(cmd, creator, cloudprovider.Azure)
}
// azureIAMCreateFlags contains the parsed flags of the iam create azure command.
type azureIAMCreateFlags struct {
region string
resourceGroup string
servicePrincipal string
}
func (f *azureIAMCreateFlags) parse(flags *pflag.FlagSet) error {
var err error
f.region, err = flags.GetString("region")
if err != nil {
return fmt.Errorf("getting 'region' flag: %w", err)
}
f.resourceGroup, err = flags.GetString("resourceGroup")
if err != nil {
return fmt.Errorf("getting 'resourceGroup' flag: %w", err)
}
f.servicePrincipal, err = flags.GetString("servicePrincipal")
if err != nil {
return fmt.Errorf("getting 'servicePrincipal' flag: %w", err)
}
return nil
}
// azureIAMCreator implements the providerIAMCreator interface for Azure.
type azureIAMCreator struct {
flags azureIAMCreateFlags
}
func (c *azureIAMCreator) getIAMConfigOptions() *cloudcmd.IAMConfigOptions {
return &cloudcmd.IAMConfigOptions{
Azure: cloudcmd.AzureIAMConfig{
Region: c.flags.region,
ResourceGroup: c.flags.resourceGroup,
ServicePrincipal: c.flags.servicePrincipal,
},
}
}
func (c *azureIAMCreator) printConfirmValues(cmd *cobra.Command) {
cmd.Printf("Region:\t\t\t%s\n", c.flags.region)
cmd.Printf("Resource Group:\t\t%s\n", c.flags.resourceGroup)
cmd.Printf("Service Principal:\t%s\n\n", c.flags.servicePrincipal)
}
func (c *azureIAMCreator) printOutputValues(cmd *cobra.Command, iamFile cloudcmd.IAMOutput) {
cmd.Printf("subscription:\t\t%s\n", iamFile.AzureOutput.SubscriptionID)
cmd.Printf("tenant:\t\t\t%s\n", iamFile.AzureOutput.TenantID)
cmd.Printf("location:\t\t%s\n", c.flags.region)
cmd.Printf("resourceGroup:\t\t%s\n", c.flags.resourceGroup)
cmd.Printf("userAssignedIdentity:\t%s\n", iamFile.AzureOutput.UAMIID)
}
func (c *azureIAMCreator) writeOutputValuesToConfig(conf *config.Config, iamFile cloudcmd.IAMOutput) {
conf.Provider.Azure.SubscriptionID = iamFile.AzureOutput.SubscriptionID
conf.Provider.Azure.TenantID = iamFile.AzureOutput.TenantID
conf.Provider.Azure.Location = c.flags.region
conf.Provider.Azure.ResourceGroup = c.flags.resourceGroup
conf.Provider.Azure.UserAssignedIdentity = iamFile.AzureOutput.UAMIID
}
func (c *azureIAMCreator) parseAndWriteIDFile(_ cloudcmd.IAMOutput, _ file.Handler) error {
return nil
}
func (c *azureIAMCreator) validateConfigWithFlagCompatibility(conf config.Config) error {
return validateConfigWithFlagCompatibility(cloudprovider.Azure, conf, c.flags.region)
}

View file

@ -0,0 +1,153 @@
/*
Copyright (c) Edgeless Systems GmbH
SPDX-License-Identifier: AGPL-3.0-only
*/
package cmd
import (
"fmt"
"strings"
"github.com/edgelesssys/constellation/v2/cli/internal/cloudcmd"
"github.com/edgelesssys/constellation/v2/internal/cloud/cloudprovider"
"github.com/edgelesssys/constellation/v2/internal/config"
"github.com/edgelesssys/constellation/v2/internal/constants"
"github.com/edgelesssys/constellation/v2/internal/file"
"github.com/spf13/cobra"
"github.com/spf13/pflag"
)
// NewIAMCreateGCPCmd returns a new cobra.Command for the iam create gcp command.
func newIAMCreateGCPCmd() *cobra.Command {
cmd := &cobra.Command{
Use: "gcp",
Short: "Create IAM configuration on GCP for your Constellation cluster",
Long: "Create IAM configuration on GCP for your Constellation cluster.",
Args: cobra.ExactArgs(0),
RunE: runIAMCreateGCP,
}
cmd.Flags().String("zone", "", "GCP zone the cluster will be deployed in (required)\n"+
"Find a list of available zones here: https://cloud.google.com/compute/docs/regions-zones#available")
must(cobra.MarkFlagRequired(cmd.Flags(), "zone"))
cmd.Flags().String("serviceAccountID", "", "ID for the service account that will be created (required)\n"+
"Must be 6 to 30 lowercase letters, digits, or hyphens.")
must(cobra.MarkFlagRequired(cmd.Flags(), "serviceAccountID"))
cmd.Flags().String("projectID", "", "ID of the GCP project the configuration will be created in (required)\n"+
"Find it on the welcome screen of your project: https://console.cloud.google.com/welcome")
must(cobra.MarkFlagRequired(cmd.Flags(), "projectID"))
return cmd
}
func runIAMCreateGCP(cmd *cobra.Command, _ []string) error {
creator := &gcpIAMCreator{}
if err := creator.flags.parse(cmd.Flags()); err != nil {
return err
}
return runIAMCreate(cmd, creator, cloudprovider.GCP)
}
// gcpIAMCreateFlags contains the parsed flags of the iam create gcp command.
type gcpIAMCreateFlags struct {
rootFlags
serviceAccountID string
zone string
region string
projectID string
}
func (f *gcpIAMCreateFlags) parse(flags *pflag.FlagSet) error {
var err error
if err = f.rootFlags.parse(flags); err != nil {
return err
}
f.zone, err = flags.GetString("zone")
if err != nil {
return fmt.Errorf("getting 'zone' flag: %w", err)
}
if !zoneRegex.MatchString(f.zone) {
return fmt.Errorf("invalid zone string: %s", f.zone)
}
// Infer region from zone.
zoneParts := strings.Split(f.zone, "-")
f.region = fmt.Sprintf("%s-%s", zoneParts[0], zoneParts[1])
if !regionRegex.MatchString(f.region) {
return fmt.Errorf("invalid region string: %s", f.region)
}
f.projectID, err = flags.GetString("projectID")
if err != nil {
return fmt.Errorf("getting 'projectID' flag: %w", err)
}
if !gcpIDRegex.MatchString(f.projectID) {
return fmt.Errorf("projectID %q doesn't match %s", f.projectID, gcpIDRegex)
}
f.serviceAccountID, err = flags.GetString("serviceAccountID")
if err != nil {
return fmt.Errorf("getting 'serviceAccountID' flag: %w", err)
}
if !gcpIDRegex.MatchString(f.serviceAccountID) {
return fmt.Errorf("serviceAccountID %q doesn't match %s", f.serviceAccountID, gcpIDRegex)
}
return nil
}
// gcpIAMCreator implements the providerIAMCreator interface for GCP.
type gcpIAMCreator struct {
flags gcpIAMCreateFlags
}
func (c *gcpIAMCreator) getIAMConfigOptions() *cloudcmd.IAMConfigOptions {
return &cloudcmd.IAMConfigOptions{
GCP: cloudcmd.GCPIAMConfig{
Zone: c.flags.zone,
Region: c.flags.region,
ProjectID: c.flags.projectID,
ServiceAccountID: c.flags.serviceAccountID,
},
}
}
func (c *gcpIAMCreator) printConfirmValues(cmd *cobra.Command) {
cmd.Printf("Project ID:\t\t%s\n", c.flags.projectID)
cmd.Printf("Service Account ID:\t%s\n", c.flags.serviceAccountID)
cmd.Printf("Region:\t\t\t%s\n", c.flags.region)
cmd.Printf("Zone:\t\t\t%s\n\n", c.flags.zone)
}
func (c *gcpIAMCreator) printOutputValues(cmd *cobra.Command, _ cloudcmd.IAMOutput) {
cmd.Printf("projectID:\t\t%s\n", c.flags.projectID)
cmd.Printf("region:\t\t\t%s\n", c.flags.region)
cmd.Printf("zone:\t\t\t%s\n", c.flags.zone)
cmd.Printf("serviceAccountKeyPath:\t%s\n\n", c.flags.pathPrefixer.PrefixPrintablePath(constants.GCPServiceAccountKeyFilename))
}
func (c *gcpIAMCreator) writeOutputValuesToConfig(conf *config.Config, _ cloudcmd.IAMOutput) {
conf.Provider.GCP.Project = c.flags.projectID
conf.Provider.GCP.ServiceAccountKeyPath = constants.GCPServiceAccountKeyFilename // File was created in workspace, so only the filename is needed.
conf.Provider.GCP.Region = c.flags.region
conf.Provider.GCP.Zone = c.flags.zone
for groupName, group := range conf.NodeGroups {
group.Zone = c.flags.zone
conf.NodeGroups[groupName] = group
}
}
func (c *gcpIAMCreator) parseAndWriteIDFile(iamFile cloudcmd.IAMOutput, fileHandler file.Handler) error {
// GCP needs to write the service account key to a file.
tmpOut, err := parseIDFile(iamFile.GCPOutput.ServiceAccountKey)
if err != nil {
return err
}
return fileHandler.WriteJSON(constants.GCPServiceAccountKeyFilename, tmpOut, file.OptNone)
}
func (c *gcpIAMCreator) validateConfigWithFlagCompatibility(conf config.Config) error {
return validateConfigWithFlagCompatibility(cloudprovider.GCP, conf, c.flags.zone)
}

View file

@ -11,13 +11,12 @@ import (
"os" "os"
"github.com/edgelesssys/constellation/v2/cli/internal/cloudcmd" "github.com/edgelesssys/constellation/v2/cli/internal/cloudcmd"
"github.com/edgelesssys/constellation/v2/cli/internal/cmd/pathprefix"
"github.com/edgelesssys/constellation/v2/cli/internal/terraform"
"github.com/edgelesssys/constellation/v2/internal/cloud/gcpshared" "github.com/edgelesssys/constellation/v2/internal/cloud/gcpshared"
"github.com/edgelesssys/constellation/v2/internal/constants" "github.com/edgelesssys/constellation/v2/internal/constants"
"github.com/edgelesssys/constellation/v2/internal/file" "github.com/edgelesssys/constellation/v2/internal/file"
"github.com/spf13/afero" "github.com/spf13/afero"
"github.com/spf13/cobra" "github.com/spf13/cobra"
"github.com/spf13/pflag"
) )
// NewIAMDestroyCmd returns a new cobra.Command for the iam destroy subcommand. // NewIAMDestroyCmd returns a new cobra.Command for the iam destroy subcommand.
@ -35,6 +34,25 @@ func newIAMDestroyCmd() *cobra.Command {
return cmd return cmd
} }
type iamDestroyFlags struct {
rootFlags
yes bool
}
func (f *iamDestroyFlags) parse(flags *pflag.FlagSet) error {
if err := f.rootFlags.parse(flags); err != nil {
return err
}
yes, err := flags.GetBool("yes")
if err != nil {
return fmt.Errorf("getting 'yes' flag: %w", err)
}
f.yes = yes
return nil
}
func runIAMDestroy(cmd *cobra.Command, _ []string) error { func runIAMDestroy(cmd *cobra.Command, _ []string) error {
log, err := newCLILogger(cmd) log, err := newCLILogger(cmd)
if err != nil { if err != nil {
@ -46,51 +64,47 @@ func runIAMDestroy(cmd *cobra.Command, _ []string) error {
fsHandler := file.NewHandler(afero.NewOsFs()) fsHandler := file.NewHandler(afero.NewOsFs())
c := &destroyCmd{log: log} c := &destroyCmd{log: log}
if err := c.flags.parse(cmd.Flags()); err != nil {
return err
}
return c.iamDestroy(cmd, spinner, destroyer, fsHandler) return c.iamDestroy(cmd, spinner, destroyer, fsHandler)
} }
type destroyCmd struct { type destroyCmd struct {
log debugLog log debugLog
pf pathprefix.PathPrefixer flags iamDestroyFlags
} }
func (c *destroyCmd) iamDestroy(cmd *cobra.Command, spinner spinnerInterf, destroyer iamDestroyer, fsHandler file.Handler) error { func (c *destroyCmd) iamDestroy(cmd *cobra.Command, spinner spinnerInterf, destroyer iamDestroyer, fsHandler file.Handler) error {
flags, err := c.parseDestroyFlags(cmd) // check if there is a possibility that the cluster is still running by looking out for specific files
if err != nil { c.log.Debugf("Checking if %q exists", c.flags.pathPrefixer.PrefixPrintablePath(constants.AdminConfFilename))
return fmt.Errorf("parsing flags: %w", err) if _, err := fsHandler.Stat(constants.AdminConfFilename); !errors.Is(err, os.ErrNotExist) {
return fmt.Errorf("file %q still exists, please make sure to terminate your cluster before destroying your IAM configuration", c.flags.pathPrefixer.PrefixPrintablePath(constants.AdminConfFilename))
} }
// check if there is a possibility that the cluster is still running by looking out for specific files c.log.Debugf("Checking if %q exists", c.flags.pathPrefixer.PrefixPrintablePath(constants.StateFilename))
c.log.Debugf("Checking if %q exists", c.pf.PrefixPrintablePath(constants.AdminConfFilename)) if _, err := fsHandler.Stat(constants.StateFilename); !errors.Is(err, os.ErrNotExist) {
_, err = fsHandler.Stat(constants.AdminConfFilename) return fmt.Errorf("file %q still exists, please make sure to terminate your cluster before destroying your IAM configuration", c.flags.pathPrefixer.PrefixPrintablePath(constants.StateFilename))
if !errors.Is(err, os.ErrNotExist) {
return fmt.Errorf("file %q still exists, please make sure to terminate your cluster before destroying your IAM configuration", c.pf.PrefixPrintablePath(constants.AdminConfFilename))
}
c.log.Debugf("Checking if %q exists", c.pf.PrefixPrintablePath(constants.StateFilename))
_, err = fsHandler.Stat(constants.StateFilename)
if !errors.Is(err, os.ErrNotExist) {
return fmt.Errorf("file %q still exists, please make sure to terminate your cluster before destroying your IAM configuration", c.pf.PrefixPrintablePath(constants.StateFilename))
} }
gcpFileExists := false gcpFileExists := false
c.log.Debugf("Checking if %q exists", c.pf.PrefixPrintablePath(constants.GCPServiceAccountKeyFilename)) c.log.Debugf("Checking if %q exists", c.flags.pathPrefixer.PrefixPrintablePath(constants.GCPServiceAccountKeyFilename))
_, err = fsHandler.Stat(constants.GCPServiceAccountKeyFilename) if _, err := fsHandler.Stat(constants.GCPServiceAccountKeyFilename); err != nil {
if err != nil {
if !errors.Is(err, os.ErrNotExist) { if !errors.Is(err, os.ErrNotExist) {
return err return err
} }
} else { } else {
c.log.Debugf("%q exists", c.pf.PrefixPrintablePath(constants.GCPServiceAccountKeyFilename)) c.log.Debugf("%q exists", c.flags.pathPrefixer.PrefixPrintablePath(constants.GCPServiceAccountKeyFilename))
gcpFileExists = true gcpFileExists = true
} }
if !flags.yes { if !c.flags.yes {
// Confirmation // Confirmation
confirmString := "Do you really want to destroy your IAM configuration? Note that this will remove all resources in the resource group." confirmString := "Do you really want to destroy your IAM configuration? Note that this will remove all resources in the resource group."
if gcpFileExists { if gcpFileExists {
confirmString += fmt.Sprintf("\nThis will also delete %q", c.pf.PrefixPrintablePath(constants.GCPServiceAccountKeyFilename)) confirmString += fmt.Sprintf("\nThis will also delete %q", c.flags.pathPrefixer.PrefixPrintablePath(constants.GCPServiceAccountKeyFilename))
} }
ok, err := askToConfirm(cmd, confirmString) ok, err := askToConfirm(cmd, confirmString)
if err != nil { if err != nil {
@ -103,7 +117,7 @@ func (c *destroyCmd) iamDestroy(cmd *cobra.Command, spinner spinnerInterf, destr
} }
if gcpFileExists { if gcpFileExists {
c.log.Debugf("Starting to delete %q", c.pf.PrefixPrintablePath(constants.GCPServiceAccountKeyFilename)) c.log.Debugf("Starting to delete %q", c.flags.pathPrefixer.PrefixPrintablePath(constants.GCPServiceAccountKeyFilename))
proceed, err := c.deleteGCPServiceAccountKeyFile(cmd, destroyer, fsHandler) proceed, err := c.deleteGCPServiceAccountKeyFile(cmd, destroyer, fsHandler)
if err != nil { if err != nil {
return err return err
@ -118,7 +132,7 @@ func (c *destroyCmd) iamDestroy(cmd *cobra.Command, spinner spinnerInterf, destr
spinner.Start("Destroying IAM configuration", false) spinner.Start("Destroying IAM configuration", false)
defer spinner.Stop() defer spinner.Stop()
if err := destroyer.DestroyIAMConfiguration(cmd.Context(), constants.TerraformIAMWorkingDir, flags.tfLogLevel); err != nil { if err := destroyer.DestroyIAMConfiguration(cmd.Context(), constants.TerraformIAMWorkingDir, c.flags.tfLogLevel); err != nil {
return fmt.Errorf("destroying IAM configuration: %w", err) return fmt.Errorf("destroying IAM configuration: %w", err)
} }
@ -130,7 +144,7 @@ func (c *destroyCmd) iamDestroy(cmd *cobra.Command, spinner spinnerInterf, destr
func (c *destroyCmd) deleteGCPServiceAccountKeyFile(cmd *cobra.Command, destroyer iamDestroyer, fsHandler file.Handler) (bool, error) { func (c *destroyCmd) deleteGCPServiceAccountKeyFile(cmd *cobra.Command, destroyer iamDestroyer, fsHandler file.Handler) (bool, error) {
var fileSaKey gcpshared.ServiceAccountKey var fileSaKey gcpshared.ServiceAccountKey
c.log.Debugf("Parsing %q", c.pf.PrefixPrintablePath(constants.GCPServiceAccountKeyFilename)) c.log.Debugf("Parsing %q", c.flags.pathPrefixer.PrefixPrintablePath(constants.GCPServiceAccountKeyFilename))
if err := fsHandler.ReadJSON(constants.GCPServiceAccountKeyFilename, &fileSaKey); err != nil { if err := fsHandler.ReadJSON(constants.GCPServiceAccountKeyFilename, &fileSaKey); err != nil {
return false, err return false, err
} }
@ -143,7 +157,11 @@ func (c *destroyCmd) deleteGCPServiceAccountKeyFile(cmd *cobra.Command, destroye
c.log.Debugf("Checking if keys are the same") c.log.Debugf("Checking if keys are the same")
if tfSaKey != fileSaKey { if tfSaKey != fileSaKey {
cmd.Printf("The key in %q don't match up with your Terraform state. %q will not be deleted.\n", c.pf.PrefixPrintablePath(constants.GCPServiceAccountKeyFilename), c.pf.PrefixPrintablePath(constants.GCPServiceAccountKeyFilename)) cmd.Printf(
"The key in %q don't match up with your Terraform state. %q will not be deleted.\n",
c.flags.pathPrefixer.PrefixPrintablePath(constants.GCPServiceAccountKeyFilename),
c.flags.pathPrefixer.PrefixPrintablePath(constants.GCPServiceAccountKeyFilename),
)
return true, nil return true, nil
} }
@ -151,42 +169,6 @@ func (c *destroyCmd) deleteGCPServiceAccountKeyFile(cmd *cobra.Command, destroye
return false, err return false, err
} }
c.log.Debugf("Successfully deleted %q", c.pf.PrefixPrintablePath(constants.GCPServiceAccountKeyFilename)) c.log.Debugf("Successfully deleted %q", c.flags.pathPrefixer.PrefixPrintablePath(constants.GCPServiceAccountKeyFilename))
return true, nil return true, nil
} }
type destroyFlags struct {
yes bool
tfLogLevel terraform.LogLevel
}
// parseDestroyFlags parses the flags of the create command.
func (c *destroyCmd) parseDestroyFlags(cmd *cobra.Command) (destroyFlags, error) {
yes, err := cmd.Flags().GetBool("yes")
if err != nil {
return destroyFlags{}, fmt.Errorf("parsing yes bool: %w", err)
}
c.log.Debugf("Yes flag is %t", yes)
workDir, err := cmd.Flags().GetString("workspace")
if err != nil {
return destroyFlags{}, fmt.Errorf("parsing workspace string: %w", err)
}
c.log.Debugf("Workspace set to %q", workDir)
c.pf = pathprefix.New(workDir)
logLevelString, err := cmd.Flags().GetString("tf-log")
if err != nil {
return destroyFlags{}, fmt.Errorf("parsing tf-log string: %w", err)
}
logLevel, err := terraform.ParseLogLevel(logLevelString)
if err != nil {
return destroyFlags{}, fmt.Errorf("parsing Terraform log level %s: %w", logLevelString, err)
}
c.log.Debugf("Terraform logs will be written into %s at level %s", c.pf.PrefixPrintablePath(constants.TerraformWorkingDir), logLevel.String())
return destroyFlags{
tfLogLevel: logLevel,
yes: yes,
}, nil
}

View file

@ -46,52 +46,52 @@ func TestIAMDestroy(t *testing.T) {
iamDestroyer *stubIAMDestroyer iamDestroyer *stubIAMDestroyer
fh file.Handler fh file.Handler
stdin string stdin string
yesFlag string yesFlag bool
wantErr bool wantErr bool
wantDestroyCalled bool wantDestroyCalled bool
}{ }{
"cluster running admin conf": { "cluster running admin conf": {
fh: newFsWithAdminConf(), fh: newFsWithAdminConf(),
iamDestroyer: &stubIAMDestroyer{}, iamDestroyer: &stubIAMDestroyer{},
yesFlag: "false", yesFlag: false,
wantErr: true, wantErr: true,
}, },
"cluster running cluster state": { "cluster running cluster state": {
fh: newFsWithStateFile(), fh: newFsWithStateFile(),
iamDestroyer: &stubIAMDestroyer{}, iamDestroyer: &stubIAMDestroyer{},
yesFlag: "false", yesFlag: false,
wantErr: true, wantErr: true,
}, },
"file missing abort": { "file missing abort": {
fh: newFsMissing(), fh: newFsMissing(),
stdin: "n\n", stdin: "n\n",
yesFlag: "false", yesFlag: false,
iamDestroyer: &stubIAMDestroyer{}, iamDestroyer: &stubIAMDestroyer{},
}, },
"file missing": { "file missing": {
fh: newFsMissing(), fh: newFsMissing(),
stdin: "y\n", stdin: "y\n",
yesFlag: "false", yesFlag: false,
iamDestroyer: &stubIAMDestroyer{}, iamDestroyer: &stubIAMDestroyer{},
wantDestroyCalled: true, wantDestroyCalled: true,
}, },
"file exists abort": { "file exists abort": {
fh: newFsExists(), fh: newFsExists(),
stdin: "n\n", stdin: "n\n",
yesFlag: "false", yesFlag: false,
iamDestroyer: &stubIAMDestroyer{}, iamDestroyer: &stubIAMDestroyer{},
}, },
"error destroying user": { "error destroying user": {
fh: newFsMissing(), fh: newFsMissing(),
stdin: "y\n", stdin: "y\n",
yesFlag: "false", yesFlag: false,
iamDestroyer: &stubIAMDestroyer{destroyErr: someError}, iamDestroyer: &stubIAMDestroyer{destroyErr: someError},
wantErr: true, wantErr: true,
wantDestroyCalled: true, wantDestroyCalled: true,
}, },
"gcp delete error": { "gcp delete error": {
fh: newFsExists(), fh: newFsExists(),
yesFlag: "true", yesFlag: true,
iamDestroyer: &stubIAMDestroyer{getTfStateKeyErr: someError}, iamDestroyer: &stubIAMDestroyer{getTfStateKeyErr: someError},
wantErr: true, wantErr: true,
}, },
@ -106,13 +106,9 @@ func TestIAMDestroy(t *testing.T) {
cmd.SetErr(&bytes.Buffer{}) cmd.SetErr(&bytes.Buffer{})
cmd.SetIn(bytes.NewBufferString(tc.stdin)) cmd.SetIn(bytes.NewBufferString(tc.stdin))
// register persistent flags manually c := &destroyCmd{log: logger.NewTest(t), flags: iamDestroyFlags{
cmd.Flags().String("tf-log", "NONE", "") yes: tc.yesFlag,
cmd.Flags().String("workspace", "", "") }}
assert.NoError(cmd.Flags().Set("yes", tc.yesFlag))
c := &destroyCmd{log: logger.NewTest(t)}
err := c.iamDestroy(cmd, &nopSpinner{}, tc.iamDestroyer, tc.fh) err := c.iamDestroy(cmd, &nopSpinner{}, tc.iamDestroyer, tc.fh)

View file

@ -21,6 +21,7 @@ import (
"github.com/edgelesssys/constellation/v2/internal/file" "github.com/edgelesssys/constellation/v2/internal/file"
"github.com/spf13/afero" "github.com/spf13/afero"
"github.com/spf13/cobra" "github.com/spf13/cobra"
"github.com/spf13/pflag"
) )
func newIAMUpgradeCmd() *cobra.Command { func newIAMUpgradeCmd() *cobra.Command {
@ -46,17 +47,32 @@ func newIAMUpgradeApplyCmd() *cobra.Command {
return cmd return cmd
} }
type iamUpgradeApplyFlags struct {
rootFlags
yes bool
}
func (f *iamUpgradeApplyFlags) parse(flags *pflag.FlagSet) error {
if err := f.rootFlags.parse(flags); err != nil {
return err
}
yes, err := flags.GetBool("yes")
if err != nil {
return fmt.Errorf("getting 'yes' flag: %w", err)
}
f.yes = yes
return nil
}
type iamUpgradeApplyCmd struct { type iamUpgradeApplyCmd struct {
fileHandler file.Handler fileHandler file.Handler
log debugLog log debugLog
configFetcher attestationconfigapi.Fetcher configFetcher attestationconfigapi.Fetcher
flags iamUpgradeApplyFlags
} }
func runIAMUpgradeApply(cmd *cobra.Command, _ []string) error { func runIAMUpgradeApply(cmd *cobra.Command, _ []string) error {
force, err := cmd.Flags().GetBool("force")
if err != nil {
return fmt.Errorf("parsing force argument: %w", err)
}
fileHandler := file.NewHandler(afero.NewOsFs()) fileHandler := file.NewHandler(afero.NewOsFs())
upgradeID := generateUpgradeID(upgradeCmdKindIAM) upgradeID := generateUpgradeID(upgradeCmdKindIAM)
upgradeDir := filepath.Join(constants.UpgradeDir, upgradeID) upgradeDir := filepath.Join(constants.UpgradeDir, upgradeID)
@ -77,22 +93,20 @@ func runIAMUpgradeApply(cmd *cobra.Command, _ []string) error {
return fmt.Errorf("setting up logger: %w", err) return fmt.Errorf("setting up logger: %w", err)
} }
yes, err := cmd.Flags().GetBool("yes")
if err != nil {
return err
}
i := iamUpgradeApplyCmd{ i := iamUpgradeApplyCmd{
fileHandler: fileHandler, fileHandler: fileHandler,
log: log, log: log,
configFetcher: configFetcher, configFetcher: configFetcher,
} }
if err := i.flags.parse(cmd.Flags()); err != nil {
return err
}
return i.iamUpgradeApply(cmd, iamMigrateCmd, upgradeDir, force, yes) return i.iamUpgradeApply(cmd, iamMigrateCmd, upgradeDir)
} }
func (i iamUpgradeApplyCmd) iamUpgradeApply(cmd *cobra.Command, iamUpgrader iamUpgrader, upgradeDir string, force, yes bool) error { func (i iamUpgradeApplyCmd) iamUpgradeApply(cmd *cobra.Command, iamUpgrader iamUpgrader, upgradeDir string) error {
conf, err := config.New(i.fileHandler, constants.ConfigFilename, i.configFetcher, force) conf, err := config.New(i.fileHandler, constants.ConfigFilename, i.configFetcher, i.flags.force)
var configValidationErr *config.ValidationError var configValidationErr *config.ValidationError
if errors.As(err, &configValidationErr) { if errors.As(err, &configValidationErr) {
cmd.PrintErrln(configValidationErr.LongMessage()) cmd.PrintErrln(configValidationErr.LongMessage())
@ -109,14 +123,14 @@ func (i iamUpgradeApplyCmd) iamUpgradeApply(cmd *cobra.Command, iamUpgrader iamU
if err != nil { if err != nil {
return fmt.Errorf("planning terraform migrations: %w", err) return fmt.Errorf("planning terraform migrations: %w", err)
} }
if !hasDiff && !force { if !hasDiff && !i.flags.force {
cmd.Println("No IAM migrations necessary.") cmd.Println("No IAM migrations necessary.")
return nil return nil
} }
// If there are any Terraform migrations to apply, ask for confirmation // If there are any Terraform migrations to apply, ask for confirmation
cmd.Println("The IAM upgrade requires a migration by applying an updated Terraform template. Please manually review the suggested changes.") cmd.Println("The IAM upgrade requires a migration by applying an updated Terraform template. Please manually review the suggested changes.")
if !yes { if !i.flags.yes {
ok, err := askToConfirm(cmd, "Do you want to apply the IAM upgrade?") ok, err := askToConfirm(cmd, "Do you want to apply the IAM upgrade?")
if err != nil { if err != nil {
return fmt.Errorf("asking for confirmation: %w", err) return fmt.Errorf("asking for confirmation: %w", err)

View file

@ -132,9 +132,12 @@ func TestIamUpgradeApply(t *testing.T) {
fileHandler: tc.fh, fileHandler: tc.fh,
log: logger.NewTest(t), log: logger.NewTest(t),
configFetcher: tc.configFetcher, configFetcher: tc.configFetcher,
flags: iamUpgradeApplyFlags{
yes: tc.yesFlag,
},
} }
err := iamUpgradeApplyCmd.iamUpgradeApply(cmd, tc.iamUpgrader, "", false, tc.yesFlag) err := iamUpgradeApplyCmd.iamUpgradeApply(cmd, tc.iamUpgrader, "")
if tc.wantErr { if tc.wantErr {
assert.Error(err) assert.Error(err)
} else { } else {

View file

@ -28,6 +28,7 @@ import (
"github.com/spf13/afero" "github.com/spf13/afero"
"github.com/spf13/cobra" "github.com/spf13/cobra"
"github.com/spf13/pflag"
"google.golang.org/grpc" "google.golang.org/grpc"
"k8s.io/apimachinery/pkg/runtime" "k8s.io/apimachinery/pkg/runtime"
"k8s.io/client-go/tools/clientcmd" "k8s.io/client-go/tools/clientcmd"
@ -36,7 +37,6 @@ import (
"github.com/edgelesssys/constellation/v2/bootstrapper/initproto" "github.com/edgelesssys/constellation/v2/bootstrapper/initproto"
"github.com/edgelesssys/constellation/v2/cli/internal/cloudcmd" "github.com/edgelesssys/constellation/v2/cli/internal/cloudcmd"
"github.com/edgelesssys/constellation/v2/cli/internal/cmd/pathprefix"
"github.com/edgelesssys/constellation/v2/cli/internal/helm" "github.com/edgelesssys/constellation/v2/cli/internal/helm"
"github.com/edgelesssys/constellation/v2/cli/internal/kubecmd" "github.com/edgelesssys/constellation/v2/cli/internal/kubecmd"
"github.com/edgelesssys/constellation/v2/cli/internal/state" "github.com/edgelesssys/constellation/v2/cli/internal/state"
@ -69,12 +69,45 @@ func NewInitCmd() *cobra.Command {
return cmd return cmd
} }
// initFlags are flags used by the init command.
type initFlags struct {
rootFlags
conformance bool
helmWaitMode helm.WaitMode
mergeConfigs bool
}
func (f *initFlags) parse(flags *pflag.FlagSet) error {
if err := f.rootFlags.parse(flags); err != nil {
return err
}
skipHelmWait, err := flags.GetBool("skip-helm-wait")
if err != nil {
return fmt.Errorf("getting 'skip-helm-wait' flag: %w", err)
}
f.helmWaitMode = helm.WaitModeAtomic
if skipHelmWait {
f.helmWaitMode = helm.WaitModeNone
}
f.conformance, err = flags.GetBool("conformance")
if err != nil {
return fmt.Errorf("getting 'conformance' flag: %w", err)
}
f.mergeConfigs, err = flags.GetBool("merge-kubeconfig")
if err != nil {
return fmt.Errorf("getting 'merge-kubeconfig' flag: %w", err)
}
return nil
}
type initCmd struct { type initCmd struct {
log debugLog log debugLog
merger configMerger merger configMerger
spinner spinnerInterf spinner spinnerInterf
fileHandler file.Handler fileHandler file.Handler
pf pathprefix.PathPrefixer flags initFlags
} }
func newInitCmd(fileHandler file.Handler, spinner spinnerInterf, merger configMerger, log debugLog) *initCmd { func newInitCmd(fileHandler file.Handler, spinner spinnerInterf, merger configMerger, log debugLog) *initCmd {
@ -109,6 +142,11 @@ func runInitialize(cmd *cobra.Command, _ []string) error {
cmd.SetContext(ctx) cmd.SetContext(ctx)
i := newInitCmd(fileHandler, spinner, &kubeconfigMerger{log: log}, log) i := newInitCmd(fileHandler, spinner, &kubeconfigMerger{log: log}, log)
if err := i.flags.parse(cmd.Flags()); err != nil {
return err
}
i.log.Debugf("Using flags: %+v", i.flags)
fetcher := attestationconfigapi.NewFetcher() fetcher := attestationconfigapi.NewFetcher()
newAttestationApplier := func(w io.Writer, kubeConfig string, log debugLog) (attestationConfigApplier, error) { newAttestationApplier := func(w io.Writer, kubeConfig string, log debugLog) (attestationConfigApplier, error) {
return kubecmd.New(w, kubeConfig, fileHandler, log) return kubecmd.New(w, kubeConfig, fileHandler, log)
@ -127,13 +165,8 @@ func (i *initCmd) initialize(
newAttestationApplier func(io.Writer, string, debugLog) (attestationConfigApplier, error), newAttestationApplier func(io.Writer, string, debugLog) (attestationConfigApplier, error),
newHelmClient func(kubeConfigPath string, log debugLog) (helmApplier, error), newHelmClient func(kubeConfigPath string, log debugLog) (helmApplier, error),
) error { ) error {
flags, err := i.evalFlagArgs(cmd) i.log.Debugf("Loading configuration file from %q", i.flags.pathPrefixer.PrefixPrintablePath(constants.ConfigFilename))
if err != nil { conf, err := config.New(i.fileHandler, constants.ConfigFilename, configFetcher, i.flags.force)
return err
}
i.log.Debugf("Using flags: %+v", flags)
i.log.Debugf("Loading configuration file from %q", i.pf.PrefixPrintablePath(constants.ConfigFilename))
conf, err := config.New(i.fileHandler, constants.ConfigFilename, configFetcher, flags.force)
var configValidationErr *config.ValidationError var configValidationErr *config.ValidationError
if errors.As(err, &configValidationErr) { if errors.As(err, &configValidationErr) {
cmd.PrintErrln(configValidationErr.LongMessage()) cmd.PrintErrln(configValidationErr.LongMessage())
@ -146,7 +179,7 @@ func (i *initCmd) initialize(
if err != nil { if err != nil {
return err return err
} }
if !flags.force { if !i.flags.force {
if err := validateCLIandConstellationVersionAreEqual(constants.BinaryVersion(), conf.Image, conf.MicroserviceVersion); err != nil { if err := validateCLIandConstellationVersionAreEqual(constants.BinaryVersion(), conf.Image, conf.MicroserviceVersion); err != nil {
return err return err
} }
@ -183,7 +216,7 @@ func (i *initCmd) initialize(
return fmt.Errorf("creating new validator: %w", err) return fmt.Errorf("creating new validator: %w", err)
} }
i.log.Debugf("Created a new validator") i.log.Debugf("Created a new validator")
serviceAccURI, err := cloudcmd.GetMarshaledServiceAccountURI(provider, conf, i.pf, i.log, i.fileHandler) serviceAccURI, err := cloudcmd.GetMarshaledServiceAccountURI(provider, conf, i.flags.pathPrefixer, i.log, i.fileHandler)
if err != nil { if err != nil {
return err return err
} }
@ -211,7 +244,7 @@ func (i *initCmd) initialize(
MeasurementSalt: measurementSalt, MeasurementSalt: measurementSalt,
KubernetesVersion: versions.VersionConfigs[k8sVersion].ClusterVersion, KubernetesVersion: versions.VersionConfigs[k8sVersion].ClusterVersion,
KubernetesComponents: versions.VersionConfigs[k8sVersion].KubernetesComponents.ToInitProto(), KubernetesComponents: versions.VersionConfigs[k8sVersion].KubernetesComponents.ToInitProto(),
ConformanceMode: flags.conformance, ConformanceMode: i.flags.conformance,
InitSecret: stateFile.Infrastructure.InitSecret, InitSecret: stateFile.Infrastructure.InitSecret,
ClusterName: stateFile.Infrastructure.Name, ClusterName: stateFile.Infrastructure.Name,
ApiserverCertSans: stateFile.Infrastructure.APIServerCertSANs, ApiserverCertSans: stateFile.Infrastructure.APIServerCertSANs,
@ -228,7 +261,7 @@ func (i *initCmd) initialize(
if nonRetriable.logCollectionErr != nil { if nonRetriable.logCollectionErr != nil {
cmd.PrintErrf("Failed to collect logs from bootstrapper: %s\n", nonRetriable.logCollectionErr) cmd.PrintErrf("Failed to collect logs from bootstrapper: %s\n", nonRetriable.logCollectionErr)
} else { } else {
cmd.PrintErrf("Fetched bootstrapper logs are stored in %q\n", i.pf.PrefixPrintablePath(constants.ErrorLog)) cmd.PrintErrf("Fetched bootstrapper logs are stored in %q\n", i.flags.pathPrefixer.PrefixPrintablePath(constants.ErrorLog))
} }
} }
return err return err
@ -236,7 +269,7 @@ func (i *initCmd) initialize(
i.log.Debugf("Initialization request succeeded") i.log.Debugf("Initialization request succeeded")
bufferedOutput := &bytes.Buffer{} bufferedOutput := &bytes.Buffer{}
if err := i.writeOutput(stateFile, resp, flags.mergeConfigs, bufferedOutput, measurementSalt); err != nil { if err := i.writeOutput(stateFile, resp, i.flags.mergeConfigs, bufferedOutput, measurementSalt); err != nil {
return err return err
} }
@ -250,9 +283,9 @@ func (i *initCmd) initialize(
i.spinner.Start("Installing Kubernetes components ", false) i.spinner.Start("Installing Kubernetes components ", false)
options := helm.Options{ options := helm.Options{
Force: flags.force, Force: i.flags.force,
Conformance: flags.conformance, Conformance: i.flags.conformance,
HelmWaitMode: flags.helmWaitMode, HelmWaitMode: i.flags.helmWaitMode,
AllowDestructive: helm.DenyDestructive, AllowDestructive: helm.DenyDestructive,
} }
helmApplier, err := newHelmClient(constants.AdminConfFilename, i.log) helmApplier, err := newHelmClient(constants.AdminConfFilename, i.log)
@ -457,7 +490,7 @@ func (i *initCmd) writeOutput(
tw := tabwriter.NewWriter(wr, 0, 0, 2, ' ', 0) tw := tabwriter.NewWriter(wr, 0, 0, 2, ' ', 0)
writeRow(tw, "Constellation cluster identifier", clusterID) writeRow(tw, "Constellation cluster identifier", clusterID)
writeRow(tw, "Kubernetes configuration", i.pf.PrefixPrintablePath(constants.AdminConfFilename)) writeRow(tw, "Kubernetes configuration", i.flags.pathPrefixer.PrefixPrintablePath(constants.AdminConfFilename))
tw.Flush() tw.Flush()
fmt.Fprintln(wr) fmt.Fprintln(wr)
@ -485,7 +518,7 @@ func (i *initCmd) writeOutput(
if err := i.fileHandler.Write(constants.AdminConfFilename, kubeconfigBytes, file.OptNone); err != nil { if err := i.fileHandler.Write(constants.AdminConfFilename, kubeconfigBytes, file.OptNone); err != nil {
return fmt.Errorf("writing kubeconfig: %w", err) return fmt.Errorf("writing kubeconfig: %w", err)
} }
i.log.Debugf("Kubeconfig written to %s", i.pf.PrefixPrintablePath(constants.AdminConfFilename)) i.log.Debugf("Kubeconfig written to %s", i.flags.pathPrefixer.PrefixPrintablePath(constants.AdminConfFilename))
if mergeConfig { if mergeConfig {
if err := i.merger.mergeConfigs(constants.AdminConfFilename, i.fileHandler); err != nil { if err := i.merger.mergeConfigs(constants.AdminConfFilename, i.fileHandler); err != nil {
@ -500,7 +533,7 @@ func (i *initCmd) writeOutput(
return fmt.Errorf("writing Constellation state file: %w", err) return fmt.Errorf("writing Constellation state file: %w", err)
} }
i.log.Debugf("Constellation state file written to %s", i.pf.PrefixPrintablePath(constants.StateFilename)) i.log.Debugf("Constellation state file written to %s", i.flags.pathPrefixer.PrefixPrintablePath(constants.StateFilename))
if !mergeConfig { if !mergeConfig {
fmt.Fprintln(wr, "You can now connect to your cluster by executing:") fmt.Fprintln(wr, "You can now connect to your cluster by executing:")
@ -528,57 +561,6 @@ func writeRow(wr io.Writer, col1 string, col2 string) {
fmt.Fprint(wr, col1, "\t", col2, "\n") fmt.Fprint(wr, col1, "\t", col2, "\n")
} }
// evalFlagArgs gets the flag values and does preprocessing of these values like
// reading the content from file path flags and deriving other values from flag combinations.
func (i *initCmd) evalFlagArgs(cmd *cobra.Command) (initFlags, error) {
conformance, err := cmd.Flags().GetBool("conformance")
if err != nil {
return initFlags{}, fmt.Errorf("parsing conformance flag: %w", err)
}
i.log.Debugf("Conformance flag is %t", conformance)
skipHelmWait, err := cmd.Flags().GetBool("skip-helm-wait")
if err != nil {
return initFlags{}, fmt.Errorf("parsing skip-helm-wait flag: %w", err)
}
helmWaitMode := helm.WaitModeAtomic
if skipHelmWait {
helmWaitMode = helm.WaitModeNone
}
i.log.Debugf("Helm wait flag is %t", skipHelmWait)
workDir, err := cmd.Flags().GetString("workspace")
if err != nil {
return initFlags{}, fmt.Errorf("parsing config path flag: %w", err)
}
i.pf = pathprefix.New(workDir)
mergeConfigs, err := cmd.Flags().GetBool("merge-kubeconfig")
if err != nil {
return initFlags{}, fmt.Errorf("parsing merge-kubeconfig flag: %w", err)
}
i.log.Debugf("Merge kubeconfig flag is %t", mergeConfigs)
force, err := cmd.Flags().GetBool("force")
if err != nil {
return initFlags{}, fmt.Errorf("parsing force argument: %w", err)
}
i.log.Debugf("force flag is %t", force)
return initFlags{
conformance: conformance,
helmWaitMode: helmWaitMode,
force: force,
mergeConfigs: mergeConfigs,
}, nil
}
// initFlags are the resulting values of flag preprocessing.
type initFlags struct {
conformance bool
helmWaitMode helm.WaitMode
force bool
mergeConfigs bool
}
// generateMasterSecret reads a base64 encoded master secret from file or generates a new 32 byte secret. // generateMasterSecret reads a base64 encoded master secret from file or generates a new 32 byte secret.
func (i *initCmd) generateMasterSecret(outWriter io.Writer) (uri.MasterSecret, error) { func (i *initCmd) generateMasterSecret(outWriter io.Writer) (uri.MasterSecret, error) {
// No file given, generate a new secret, and save it to disk // No file given, generate a new secret, and save it to disk
@ -599,7 +581,7 @@ func (i *initCmd) generateMasterSecret(outWriter io.Writer) (uri.MasterSecret, e
if err := i.fileHandler.WriteJSON(constants.MasterSecretFilename, secret, file.OptNone); err != nil { if err := i.fileHandler.WriteJSON(constants.MasterSecretFilename, secret, file.OptNone); err != nil {
return uri.MasterSecret{}, err return uri.MasterSecret{}, err
} }
fmt.Fprintf(outWriter, "Your Constellation master secret was successfully written to %q\n", i.pf.PrefixPrintablePath(constants.MasterSecretFilename)) fmt.Fprintf(outWriter, "Your Constellation master secret was successfully written to %q\n", i.flags.pathPrefixer.PrefixPrintablePath(constants.MasterSecretFilename))
return secret, nil return secret, nil
} }

View file

@ -223,10 +223,6 @@ func TestInitialize(t *testing.T) {
var errOut bytes.Buffer var errOut bytes.Buffer
cmd.SetErr(&errOut) cmd.SetErr(&errOut)
// Flags
cmd.Flags().String("workspace", "", "") // register persistent flag manually
cmd.Flags().Bool("force", true, "") // register persistent flag manually
// File system preparation // File system preparation
fs := afero.NewMemMapFs() fs := afero.NewMemMapFs()
fileHandler := file.NewHandler(fs) fileHandler := file.NewHandler(fs)
@ -249,6 +245,8 @@ func TestInitialize(t *testing.T) {
defer cancel() defer cancel()
cmd.SetContext(ctx) cmd.SetContext(ctx)
i := newInitCmd(fileHandler, &nopSpinner{}, nil, logger.NewTest(t)) i := newInitCmd(fileHandler, &nopSpinner{}, nil, logger.NewTest(t))
i.flags.force = true
err := i.initialize( err := i.initialize(
cmd, cmd,
newDialer, newDialer,
@ -442,15 +440,15 @@ func TestWriteOutput(t *testing.T) {
require.NoError(afs.Remove(constants.AdminConfFilename)) require.NoError(afs.Remove(constants.AdminConfFilename))
// test custom workspace // test custom workspace
i.pf = pathprefix.New("/some/path") i.flags.pathPrefixer = pathprefix.New("/some/path")
err = i.writeOutput(stateFile, resp.GetInitSuccess(), true, &out, measurementSalt) err = i.writeOutput(stateFile, resp.GetInitSuccess(), true, &out, measurementSalt)
require.NoError(err) require.NoError(err)
assert.Contains(out.String(), clusterID) assert.Contains(out.String(), clusterID)
assert.Contains(out.String(), i.pf.PrefixPrintablePath(constants.AdminConfFilename)) assert.Contains(out.String(), i.flags.pathPrefixer.PrefixPrintablePath(constants.AdminConfFilename))
out.Reset() out.Reset()
// File is written to current working dir, we simply pass the workspace for generating readable user output // File is written to current working dir, we simply pass the workspace for generating readable user output
require.NoError(afs.Remove(constants.AdminConfFilename)) require.NoError(afs.Remove(constants.AdminConfFilename))
i.pf = pathprefix.PathPrefixer{} i.flags.pathPrefixer = pathprefix.PathPrefixer{}
// test config merging // test config merging
err = i.writeOutput(stateFile, resp.GetInitSuccess(), true, &out, measurementSalt) err = i.writeOutput(stateFile, resp.GetInitSuccess(), true, &out, measurementSalt)

View file

@ -14,13 +14,11 @@ import (
"net" "net"
"github.com/edgelesssys/constellation/v2/cli/internal/cloudcmd" "github.com/edgelesssys/constellation/v2/cli/internal/cloudcmd"
"github.com/edgelesssys/constellation/v2/cli/internal/cmd/pathprefix"
"github.com/edgelesssys/constellation/v2/cli/internal/featureset" "github.com/edgelesssys/constellation/v2/cli/internal/featureset"
"github.com/edgelesssys/constellation/v2/cli/internal/helm" "github.com/edgelesssys/constellation/v2/cli/internal/helm"
"github.com/edgelesssys/constellation/v2/cli/internal/kubecmd" "github.com/edgelesssys/constellation/v2/cli/internal/kubecmd"
"github.com/edgelesssys/constellation/v2/cli/internal/libvirt" "github.com/edgelesssys/constellation/v2/cli/internal/libvirt"
"github.com/edgelesssys/constellation/v2/cli/internal/state" "github.com/edgelesssys/constellation/v2/cli/internal/state"
"github.com/edgelesssys/constellation/v2/cli/internal/terraform"
"github.com/edgelesssys/constellation/v2/internal/api/attestationconfigapi" "github.com/edgelesssys/constellation/v2/internal/api/attestationconfigapi"
"github.com/edgelesssys/constellation/v2/internal/atls" "github.com/edgelesssys/constellation/v2/internal/atls"
"github.com/edgelesssys/constellation/v2/internal/cloud/cloudprovider" "github.com/edgelesssys/constellation/v2/internal/cloud/cloudprovider"
@ -51,6 +49,8 @@ func newMiniUpCmd() *cobra.Command {
type miniUpCmd struct { type miniUpCmd struct {
log debugLog log debugLog
configFetcher attestationconfigapi.Fetcher configFetcher attestationconfigapi.Fetcher
fileHandler file.Handler
flags rootFlags
} }
func runUp(cmd *cobra.Command, _ []string) error { func runUp(cmd *cobra.Command, _ []string) error {
@ -66,7 +66,14 @@ func runUp(cmd *cobra.Command, _ []string) error {
defer spinner.Stop() defer spinner.Stop()
creator := cloudcmd.NewCreator(spinner) creator := cloudcmd.NewCreator(spinner)
m := &miniUpCmd{log: log, configFetcher: attestationconfigapi.NewFetcher()} m := &miniUpCmd{
log: log,
configFetcher: attestationconfigapi.NewFetcher(),
fileHandler: file.NewHandler(afero.NewOsFs()),
}
if err := m.flags.parse(cmd.Flags()); err != nil {
return err
}
return m.up(cmd, creator, spinner) return m.up(cmd, creator, spinner)
} }
@ -75,22 +82,15 @@ func (m *miniUpCmd) up(cmd *cobra.Command, creator cloudCreator, spinner spinner
return fmt.Errorf("system requirements not met: %w", err) return fmt.Errorf("system requirements not met: %w", err)
} }
flags, err := m.parseUpFlags(cmd)
if err != nil {
return fmt.Errorf("parsing flags: %w", err)
}
fileHandler := file.NewHandler(afero.NewOsFs())
// create config if not passed as flag and set default values // create config if not passed as flag and set default values
config, err := m.prepareConfig(cmd, fileHandler, flags) config, err := m.prepareConfig(cmd)
if err != nil { if err != nil {
return fmt.Errorf("preparing config: %w", err) return fmt.Errorf("preparing config: %w", err)
} }
// create cluster // create cluster
spinner.Start("Creating cluster in QEMU ", false) spinner.Start("Creating cluster in QEMU ", false)
err = m.createMiniCluster(cmd.Context(), fileHandler, creator, config, flags) err = m.createMiniCluster(cmd.Context(), creator, config)
spinner.Stop() spinner.Stop()
if err != nil { if err != nil {
return fmt.Errorf("creating cluster: %w", err) return fmt.Errorf("creating cluster: %w", err)
@ -105,7 +105,7 @@ func (m *miniUpCmd) up(cmd *cobra.Command, creator cloudCreator, spinner spinner
cmd.Printf("\tvirsh -c %s\n\n", connectURI) cmd.Printf("\tvirsh -c %s\n\n", connectURI)
// initialize cluster // initialize cluster
if err := m.initializeMiniCluster(cmd, fileHandler, spinner); err != nil { if err := m.initializeMiniCluster(cmd, spinner); err != nil {
return fmt.Errorf("initializing cluster: %w", err) return fmt.Errorf("initializing cluster: %w", err)
} }
m.log.Debugf("Initialized cluster") m.log.Debugf("Initialized cluster")
@ -113,8 +113,8 @@ func (m *miniUpCmd) up(cmd *cobra.Command, creator cloudCreator, spinner spinner
} }
// prepareConfig reads a given config, or creates a new minimal QEMU config. // prepareConfig reads a given config, or creates a new minimal QEMU config.
func (m *miniUpCmd) prepareConfig(cmd *cobra.Command, fileHandler file.Handler, flags upFlags) (*config.Config, error) { func (m *miniUpCmd) prepareConfig(cmd *cobra.Command) (*config.Config, error) {
_, err := fileHandler.Stat(constants.ConfigFilename) _, err := m.fileHandler.Stat(constants.ConfigFilename)
if err == nil { if err == nil {
// config already exists, prompt user if they want to use this file // config already exists, prompt user if they want to use this file
cmd.PrintErrln("A config file already exists in the configured workspace.") cmd.PrintErrln("A config file already exists in the configured workspace.")
@ -123,7 +123,7 @@ func (m *miniUpCmd) prepareConfig(cmd *cobra.Command, fileHandler file.Handler,
return nil, err return nil, err
} }
if ok { if ok {
return m.prepareExistingConfig(cmd, fileHandler, flags) return m.prepareExistingConfig(cmd)
} }
// user declined to reuse config file, prompt if they want to overwrite it // user declined to reuse config file, prompt if they want to overwrite it
@ -146,11 +146,11 @@ func (m *miniUpCmd) prepareConfig(cmd *cobra.Command, fileHandler file.Handler,
} }
m.log.Debugf("Prepared configuration") m.log.Debugf("Prepared configuration")
return config, fileHandler.WriteYAML(constants.ConfigFilename, config, file.OptOverwrite) return config, m.fileHandler.WriteYAML(constants.ConfigFilename, config, file.OptOverwrite)
} }
func (m *miniUpCmd) prepareExistingConfig(cmd *cobra.Command, fileHandler file.Handler, flags upFlags) (*config.Config, error) { func (m *miniUpCmd) prepareExistingConfig(cmd *cobra.Command) (*config.Config, error) {
conf, err := config.New(fileHandler, constants.ConfigFilename, m.configFetcher, flags.force) conf, err := config.New(m.fileHandler, constants.ConfigFilename, m.configFetcher, m.flags.force)
var configValidationErr *config.ValidationError var configValidationErr *config.ValidationError
if errors.As(err, &configValidationErr) { if errors.As(err, &configValidationErr) {
cmd.PrintErrln(configValidationErr.LongMessage()) cmd.PrintErrln(configValidationErr.LongMessage())
@ -165,13 +165,13 @@ func (m *miniUpCmd) prepareExistingConfig(cmd *cobra.Command, fileHandler file.H
} }
// createMiniCluster creates a new cluster using the given config. // createMiniCluster creates a new cluster using the given config.
func (m *miniUpCmd) createMiniCluster(ctx context.Context, fileHandler file.Handler, creator cloudCreator, config *config.Config, flags upFlags) error { func (m *miniUpCmd) createMiniCluster(ctx context.Context, creator cloudCreator, config *config.Config) error {
m.log.Debugf("Creating mini cluster") m.log.Debugf("Creating mini cluster")
opts := cloudcmd.CreateOptions{ opts := cloudcmd.CreateOptions{
Provider: cloudprovider.QEMU, Provider: cloudprovider.QEMU,
Config: config, Config: config,
TFWorkspace: constants.TerraformWorkingDir, TFWorkspace: constants.TerraformWorkingDir,
TFLogLevel: flags.tfLogLevel, TFLogLevel: m.flags.tfLogLevel,
} }
infraState, err := creator.Create(ctx, opts) infraState, err := creator.Create(ctx, opts)
if err != nil { if err != nil {
@ -184,11 +184,11 @@ func (m *miniUpCmd) createMiniCluster(ctx context.Context, fileHandler file.Hand
SetInfrastructure(infraState) SetInfrastructure(infraState)
m.log.Debugf("Cluster state file contains %v", stateFile) m.log.Debugf("Cluster state file contains %v", stateFile)
return stateFile.WriteToFile(fileHandler, constants.StateFilename) return stateFile.WriteToFile(m.fileHandler, constants.StateFilename)
} }
// initializeMiniCluster initializes a QEMU cluster. // initializeMiniCluster initializes a QEMU cluster.
func (m *miniUpCmd) initializeMiniCluster(cmd *cobra.Command, fileHandler file.Handler, spinner spinnerInterf) (retErr error) { func (m *miniUpCmd) initializeMiniCluster(cmd *cobra.Command, spinner spinnerInterf) (retErr error) {
m.log.Debugf("Initializing mini cluster") m.log.Debugf("Initializing mini cluster")
// clean up cluster resources if initialization fails // clean up cluster resources if initialization fails
defer func() { defer func() {
@ -214,12 +214,17 @@ func (m *miniUpCmd) initializeMiniCluster(cmd *cobra.Command, fileHandler file.H
defer log.Sync() defer log.Sync()
newAttestationApplier := func(w io.Writer, kubeConfig string, log debugLog) (attestationConfigApplier, error) { newAttestationApplier := func(w io.Writer, kubeConfig string, log debugLog) (attestationConfigApplier, error) {
return kubecmd.New(w, kubeConfig, fileHandler, log) return kubecmd.New(w, kubeConfig, m.fileHandler, log)
} }
newHelmClient := func(kubeConfigPath string, log debugLog) (helmApplier, error) { newHelmClient := func(kubeConfigPath string, log debugLog) (helmApplier, error) {
return helm.NewClient(kubeConfigPath, log) return helm.NewClient(kubeConfigPath, log)
} // need to defer helm client instantiation until kubeconfig is available } // need to defer helm client instantiation until kubeconfig is available
i := newInitCmd(fileHandler, spinner, &kubeconfigMerger{log: log}, log)
i := newInitCmd(m.fileHandler, spinner, &kubeconfigMerger{log: log}, log)
if err := i.flags.parse(cmd.Flags()); err != nil {
return err
}
if err := i.initialize(cmd, newDialer, license.NewClient(), m.configFetcher, if err := i.initialize(cmd, newDialer, license.NewClient(), m.configFetcher,
newAttestationApplier, newHelmClient); err != nil { newAttestationApplier, newHelmClient); err != nil {
return err return err
@ -227,37 +232,3 @@ func (m *miniUpCmd) initializeMiniCluster(cmd *cobra.Command, fileHandler file.H
m.log.Debugf("Initialized mini cluster") m.log.Debugf("Initialized mini cluster")
return nil return nil
} }
type upFlags struct {
force bool
tfLogLevel terraform.LogLevel
}
func (m *miniUpCmd) parseUpFlags(cmd *cobra.Command) (upFlags, error) {
m.log.Debugf("Preparing configuration")
workDir, err := cmd.Flags().GetString("workspace")
if err != nil {
return upFlags{}, fmt.Errorf("parsing config string: %w", err)
}
m.log.Debugf("Workspace set to %q", workDir)
force, err := cmd.Flags().GetBool("force")
if err != nil {
return upFlags{}, fmt.Errorf("parsing force bool: %w", err)
}
m.log.Debugf("force flag is %q", force)
logLevelString, err := cmd.Flags().GetString("tf-log")
if err != nil {
return upFlags{}, fmt.Errorf("parsing tf-log string: %w", err)
}
logLevel, err := terraform.ParseLogLevel(logLevelString)
if err != nil {
return upFlags{}, fmt.Errorf("parsing Terraform log level %s: %w", logLevelString, err)
}
m.log.Debugf("Terraform logs will be written into %s at level %s", pathprefix.New(workDir).PrefixPrintablePath(constants.TerraformLogFile), logLevel.String())
return upFlags{
force: force,
tfLogLevel: logLevel,
}, nil
}

View file

@ -16,7 +16,6 @@ import (
"time" "time"
"github.com/edgelesssys/constellation/v2/cli/internal/cloudcmd" "github.com/edgelesssys/constellation/v2/cli/internal/cloudcmd"
"github.com/edgelesssys/constellation/v2/cli/internal/cmd/pathprefix"
"github.com/edgelesssys/constellation/v2/cli/internal/state" "github.com/edgelesssys/constellation/v2/cli/internal/state"
"github.com/edgelesssys/constellation/v2/disk-mapper/recoverproto" "github.com/edgelesssys/constellation/v2/disk-mapper/recoverproto"
"github.com/edgelesssys/constellation/v2/internal/api/attestationconfigapi" "github.com/edgelesssys/constellation/v2/internal/api/attestationconfigapi"
@ -31,6 +30,7 @@ import (
"github.com/edgelesssys/constellation/v2/internal/retry" "github.com/edgelesssys/constellation/v2/internal/retry"
"github.com/spf13/afero" "github.com/spf13/afero"
"github.com/spf13/cobra" "github.com/spf13/cobra"
"github.com/spf13/pflag"
) )
// NewRecoverCmd returns a new cobra.Command for the recover command. // NewRecoverCmd returns a new cobra.Command for the recover command.
@ -47,10 +47,28 @@ func NewRecoverCmd() *cobra.Command {
return cmd return cmd
} }
type recoverFlags struct {
rootFlags
endpoint string
}
func (f *recoverFlags) parse(flags *pflag.FlagSet) error {
if err := f.rootFlags.parse(flags); err != nil {
return err
}
endpoint, err := flags.GetString("endpoint")
if err != nil {
return fmt.Errorf("getting 'endpoint' flag: %w", err)
}
f.endpoint = endpoint
return nil
}
type recoverCmd struct { type recoverCmd struct {
log debugLog log debugLog
configFetcher attestationconfigapi.Fetcher configFetcher attestationconfigapi.Fetcher
pf pathprefix.PathPrefixer flags recoverFlags
} }
func runRecover(cmd *cobra.Command, _ []string) error { func runRecover(cmd *cobra.Command, _ []string) error {
@ -64,6 +82,10 @@ func runRecover(cmd *cobra.Command, _ []string) error {
return dialer.New(nil, validator, &net.Dialer{}) return dialer.New(nil, validator, &net.Dialer{})
} }
r := &recoverCmd{log: log, configFetcher: attestationconfigapi.NewFetcher()} r := &recoverCmd{log: log, configFetcher: attestationconfigapi.NewFetcher()}
if err := r.flags.parse(cmd.Flags()); err != nil {
return err
}
r.log.Debugf("Using flags: %+v", r.flags)
return r.recover(cmd, fileHandler, 5*time.Second, &recoverDoer{log: r.log}, newDialer) return r.recover(cmd, fileHandler, 5*time.Second, &recoverDoer{log: r.log}, newDialer)
} }
@ -71,20 +93,14 @@ func (r *recoverCmd) recover(
cmd *cobra.Command, fileHandler file.Handler, interval time.Duration, cmd *cobra.Command, fileHandler file.Handler, interval time.Duration,
doer recoverDoerInterface, newDialer func(validator atls.Validator) *dialer.Dialer, doer recoverDoerInterface, newDialer func(validator atls.Validator) *dialer.Dialer,
) error { ) error {
flags, err := r.parseRecoverFlags(cmd, fileHandler)
if err != nil {
return err
}
r.log.Debugf("Using flags: %+v", flags)
var masterSecret uri.MasterSecret var masterSecret uri.MasterSecret
r.log.Debugf("Loading master secret file from %s", r.pf.PrefixPrintablePath(constants.MasterSecretFilename)) r.log.Debugf("Loading master secret file from %s", r.flags.pathPrefixer.PrefixPrintablePath(constants.MasterSecretFilename))
if err := fileHandler.ReadJSON(constants.MasterSecretFilename, &masterSecret); err != nil { if err := fileHandler.ReadJSON(constants.MasterSecretFilename, &masterSecret); err != nil {
return err return err
} }
r.log.Debugf("Loading configuration file from %q", r.pf.PrefixPrintablePath(constants.ConfigFilename)) r.log.Debugf("Loading configuration file from %q", r.flags.pathPrefixer.PrefixPrintablePath(constants.ConfigFilename))
conf, err := config.New(fileHandler, constants.ConfigFilename, r.configFetcher, flags.force) conf, err := config.New(fileHandler, constants.ConfigFilename, r.configFetcher, r.flags.force)
var configValidationErr *config.ValidationError var configValidationErr *config.ValidationError
if errors.As(err, &configValidationErr) { if errors.As(err, &configValidationErr) {
cmd.PrintErrln(configValidationErr.LongMessage()) cmd.PrintErrln(configValidationErr.LongMessage())
@ -99,15 +115,26 @@ func (r *recoverCmd) recover(
interval = 20 * time.Second // Azure LB takes a while to remove unhealthy instances interval = 20 * time.Second // Azure LB takes a while to remove unhealthy instances
} }
conf.UpdateMAAURL(flags.maaURL) stateFile, err := state.ReadFromFile(fileHandler, constants.StateFilename)
if err != nil {
return fmt.Errorf("reading state file: %w", err)
}
endpoint, err := r.parseEndpoint(stateFile)
if err != nil {
return err
}
if stateFile.Infrastructure.Azure != nil {
conf.UpdateMAAURL(stateFile.Infrastructure.Azure.AttestationURL)
}
r.log.Debugf("Creating aTLS Validator for %s", conf.GetAttestationConfig().GetVariant()) r.log.Debugf("Creating aTLS Validator for %s", conf.GetAttestationConfig().GetVariant())
validator, err := cloudcmd.NewValidator(cmd, conf.GetAttestationConfig(), r.log) validator, err := cloudcmd.NewValidator(cmd, conf.GetAttestationConfig(), r.log)
if err != nil { if err != nil {
return fmt.Errorf("creating new validator: %w", err) return fmt.Errorf("creating new validator: %w", err)
} }
r.log.Debugf("Created a new validator") r.log.Debugf("Created a new validator")
doer.setDialer(newDialer(validator), flags.endpoint) doer.setDialer(newDialer(validator), endpoint)
r.log.Debugf("Set dialer for endpoint %s", flags.endpoint) r.log.Debugf("Set dialer for endpoint %s", endpoint)
doer.setURIs(masterSecret.EncodeToURI(), uri.NoStoreURI) doer.setURIs(masterSecret.EncodeToURI(), uri.NoStoreURI)
r.log.Debugf("Set secrets") r.log.Debugf("Set secrets")
if err := r.recoverCall(cmd.Context(), cmd.OutOrStdout(), interval, doer); err != nil { if err := r.recoverCall(cmd.Context(), cmd.OutOrStdout(), interval, doer); err != nil {
@ -160,6 +187,18 @@ func (r *recoverCmd) recoverCall(ctx context.Context, out io.Writer, interval ti
return err return err
} }
func (r *recoverCmd) parseEndpoint(state *state.State) (string, error) {
endpoint := r.flags.endpoint
if endpoint == "" {
endpoint = state.Infrastructure.ClusterEndpoint
}
endpoint, err := addPortIfMissing(endpoint, constants.RecoveryPort)
if err != nil {
return "", fmt.Errorf("validating cluster endpoint: %w", err)
}
return endpoint, nil
}
type recoverDoerInterface interface { type recoverDoerInterface interface {
Do(ctx context.Context) error Do(ctx context.Context) error
setDialer(dialer grpcDialer, endpoint string) setDialer(dialer grpcDialer, endpoint string)
@ -209,55 +248,3 @@ func (d *recoverDoer) setURIs(kmsURI, storageURI string) {
d.kmsURI = kmsURI d.kmsURI = kmsURI
d.storageURI = storageURI d.storageURI = storageURI
} }
type recoverFlags struct {
endpoint string
maaURL string
force bool
}
func (r *recoverCmd) parseRecoverFlags(cmd *cobra.Command, fileHandler file.Handler) (recoverFlags, error) {
workDir, err := cmd.Flags().GetString("workspace")
if err != nil {
return recoverFlags{}, fmt.Errorf("parsing config path argument: %w", err)
}
r.log.Debugf("Workspace set to %q", workDir)
r.pf = pathprefix.New(workDir)
endpoint, err := cmd.Flags().GetString("endpoint")
r.log.Debugf("Endpoint flag is %s", endpoint)
if err != nil {
return recoverFlags{}, fmt.Errorf("parsing endpoint argument: %w", err)
}
force, err := cmd.Flags().GetBool("force")
if err != nil {
return recoverFlags{}, fmt.Errorf("parsing force argument: %w", err)
}
var attestationURL string
stateFile := state.New()
if endpoint == "" {
stateFile, err = state.ReadFromFile(fileHandler, constants.StateFilename)
if err != nil {
return recoverFlags{}, fmt.Errorf("reading state file: %w", err)
}
endpoint = stateFile.Infrastructure.ClusterEndpoint
}
endpoint, err = addPortIfMissing(endpoint, constants.RecoveryPort)
if err != nil {
return recoverFlags{}, fmt.Errorf("validating endpoint argument: %w", err)
}
r.log.Debugf("Endpoint value after parsing is %s", endpoint)
if stateFile.Infrastructure.Azure != nil {
attestationURL = stateFile.Infrastructure.Azure.AttestationURL
}
return recoverFlags{
endpoint: endpoint,
maaURL: attestationURL,
force: force,
}, nil
}

View file

@ -140,12 +140,9 @@ func TestRecover(t *testing.T) {
cmd := NewRecoverCmd() cmd := NewRecoverCmd()
cmd.SetContext(context.Background()) cmd.SetContext(context.Background())
cmd.Flags().String("workspace", "", "") // register persistent flag manually
cmd.Flags().Bool("force", true, "") // register persistent flag manually
out := &bytes.Buffer{} out := &bytes.Buffer{}
cmd.SetOut(out) cmd.SetOut(out)
cmd.SetErr(out) cmd.SetErr(out)
require.NoError(cmd.Flags().Set("endpoint", tc.endpoint))
fs := afero.NewMemMapFs() fs := afero.NewMemMapFs()
fileHandler := file.NewHandler(fs) fileHandler := file.NewHandler(fs)
@ -156,13 +153,25 @@ func TestRecover(t *testing.T) {
} }
require.NoError(fileHandler.WriteJSON( require.NoError(fileHandler.WriteJSON(
"constellation-mastersecret.json", constants.MasterSecretFilename,
uri.MasterSecret{Key: tc.masterSecret.Secret, Salt: tc.masterSecret.Salt}, uri.MasterSecret{Key: tc.masterSecret.Secret, Salt: tc.masterSecret.Salt},
file.OptNone, file.OptNone,
)) ))
require.NoError(fileHandler.WriteYAML(
constants.StateFilename,
state.New(),
file.OptNone,
))
newDialer := func(atls.Validator) *dialer.Dialer { return nil } newDialer := func(atls.Validator) *dialer.Dialer { return nil }
r := &recoverCmd{log: logger.NewTest(t), configFetcher: stubAttestationFetcher{}} r := &recoverCmd{
log: logger.NewTest(t),
configFetcher: stubAttestationFetcher{},
flags: recoverFlags{
rootFlags: rootFlags{force: true},
endpoint: tc.endpoint,
},
}
err := r.recover(cmd, fileHandler, time.Millisecond, tc.doer, newDialer) err := r.recover(cmd, fileHandler, time.Millisecond, tc.doer, newDialer)
if tc.wantErr { if tc.wantErr {
assert.Error(err) assert.Error(err)
@ -183,68 +192,6 @@ func TestRecover(t *testing.T) {
} }
} }
func TestParseRecoverFlags(t *testing.T) {
testCases := map[string]struct {
args []string
wantFlags recoverFlags
writeStateFile bool
wantErr bool
}{
"no flags": {
wantFlags: recoverFlags{
endpoint: "192.0.2.42:9999",
},
writeStateFile: true,
},
"no flags, no ID file": {
wantFlags: recoverFlags{
endpoint: "192.0.2.42:9999",
},
wantErr: true,
},
"invalid endpoint": {
args: []string{"-e", "192.0.2.42:2:2"},
wantErr: true,
},
"all args set": {
args: []string{"-e", "192.0.2.42:2", "--workspace", "./constellation-workspace"},
wantFlags: recoverFlags{
endpoint: "192.0.2.42:2",
},
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
cmd := NewRecoverCmd()
cmd.Flags().String("workspace", "", "") // register persistent flag manually
cmd.Flags().Bool("force", false, "") // register persistent flag manually
require.NoError(cmd.ParseFlags(tc.args))
fileHandler := file.NewHandler(afero.NewMemMapFs())
if tc.writeStateFile {
require.NoError(
state.New().
SetInfrastructure(state.Infrastructure{ClusterEndpoint: "192.0.2.42"}).
WriteToFile(fileHandler, constants.StateFilename),
)
}
r := &recoverCmd{log: logger.NewTest(t)}
flags, err := r.parseRecoverFlags(cmd, fileHandler)
if tc.wantErr {
assert.Error(err)
return
}
assert.NoError(err)
assert.Equal(tc.wantFlags, flags)
})
}
}
func TestDoRecovery(t *testing.T) { func TestDoRecovery(t *testing.T) {
testCases := map[string]struct { testCases := map[string]struct {
recoveryServer *stubRecoveryServer recoveryServer *stubRecoveryServer

View file

@ -45,11 +45,6 @@ func runStatus(cmd *cobra.Command, _ []string) error {
} }
defer log.Sync() defer log.Sync()
flags, err := parseStatusFlags(cmd)
if err != nil {
return fmt.Errorf("parsing flags: %w", err)
}
fileHandler := file.NewHandler(afero.NewOsFs()) fileHandler := file.NewHandler(afero.NewOsFs())
helmClient, err := helm.NewReleaseVersionClient(constants.AdminConfFilename, log) helmClient, err := helm.NewReleaseVersionClient(constants.AdminConfFilename, log)
@ -61,55 +56,61 @@ func runStatus(cmd *cobra.Command, _ []string) error {
} }
fetcher := attestationconfigapi.NewFetcher() fetcher := attestationconfigapi.NewFetcher()
conf, err := config.New(fileHandler, constants.ConfigFilename, fetcher, flags.force)
var configValidationErr *config.ValidationError
if errors.As(err, &configValidationErr) {
cmd.PrintErrln(configValidationErr.LongMessage())
}
variant := conf.GetAttestationConfig().GetVariant()
kubeClient, err := kubecmd.New(cmd.OutOrStdout(), constants.AdminConfFilename, fileHandler, log) kubeClient, err := kubecmd.New(cmd.OutOrStdout(), constants.AdminConfFilename, fileHandler, log)
if err != nil { if err != nil {
return fmt.Errorf("setting up kubernetes client: %w", err) return fmt.Errorf("setting up kubernetes client: %w", err)
} }
output, err := status(cmd.Context(), helmVersionGetter, kubeClient, variant) s := statusCmd{log: log, fileHandler: fileHandler}
if err != nil { if err := s.flags.parse(cmd.Flags()); err != nil {
return fmt.Errorf("getting status: %w", err) return err
} }
return s.status(cmd, helmVersionGetter, kubeClient, fetcher)
}
cmd.Print(output) type statusCmd struct {
return nil log debugLog
fileHandler file.Handler
flags rootFlags
} }
// status queries the cluster for the relevant status information and returns the output string. // status queries the cluster for the relevant status information and returns the output string.
func status(ctx context.Context, getHelmVersions func() (fmt.Stringer, error), kubeClient kubeCmd, attestVariant variant.Variant, func (s *statusCmd) status(
) (string, error) { cmd *cobra.Command, getHelmVersions func() (fmt.Stringer, error),
nodeVersion, err := kubeClient.GetConstellationVersion(ctx) kubeClient kubeCmd, fetcher attestationconfigapi.Fetcher,
if err != nil { ) error {
return "", fmt.Errorf("getting constellation version: %w", err) conf, err := config.New(s.fileHandler, constants.ConfigFilename, fetcher, s.flags.force)
var configValidationErr *config.ValidationError
if errors.As(err, &configValidationErr) {
cmd.PrintErrln(configValidationErr.LongMessage())
} }
attestationConfig, err := kubeClient.GetClusterAttestationConfig(ctx, attestVariant) nodeVersion, err := kubeClient.GetConstellationVersion(cmd.Context())
if err != nil { if err != nil {
return "", fmt.Errorf("getting attestation config: %w", err) return fmt.Errorf("getting constellation version: %w", err)
}
attestationConfig, err := kubeClient.GetClusterAttestationConfig(cmd.Context(), conf.GetAttestationConfig().GetVariant())
if err != nil {
return fmt.Errorf("getting attestation config: %w", err)
} }
prettyYAML, err := yaml.Marshal(attestationConfig) prettyYAML, err := yaml.Marshal(attestationConfig)
if err != nil { if err != nil {
return "", fmt.Errorf("marshalling attestation config: %w", err) return fmt.Errorf("marshalling attestation config: %w", err)
} }
serviceVersions, err := getHelmVersions() serviceVersions, err := getHelmVersions()
if err != nil { if err != nil {
return "", fmt.Errorf("getting service versions: %w", err) return fmt.Errorf("getting service versions: %w", err)
} }
status, err := kubeClient.ClusterStatus(ctx) status, err := kubeClient.ClusterStatus(cmd.Context())
if err != nil { if err != nil {
return "", fmt.Errorf("getting cluster status: %w", err) return fmt.Errorf("getting cluster status: %w", err)
} }
return statusOutput(nodeVersion, serviceVersions, status, string(prettyYAML)), nil cmd.Print(statusOutput(nodeVersion, serviceVersions, status, string(prettyYAML)))
return nil
} }
// statusOutput creates the status cmd output string by formatting the received information. // statusOutput creates the status cmd output string by formatting the received information.
@ -167,26 +168,6 @@ func targetVersionsString(target kubecmd.NodeVersion) string {
return builder.String() return builder.String()
} }
type statusFlags struct {
workspace string
force bool
}
func parseStatusFlags(cmd *cobra.Command) (statusFlags, error) {
workspace, err := cmd.Flags().GetString("workspace")
if err != nil {
return statusFlags{}, fmt.Errorf("getting config flag: %w", err)
}
force, err := cmd.Flags().GetBool("force")
if err != nil {
return statusFlags{}, fmt.Errorf("getting config flag: %w", err)
}
return statusFlags{
workspace: workspace,
force: force,
}, nil
}
type kubeCmd interface { type kubeCmd interface {
ClusterStatus(ctx context.Context) (map[string]kubecmd.NodeStatus, error) ClusterStatus(ctx context.Context) (map[string]kubecmd.NodeStatus, error)
GetConstellationVersion(ctx context.Context) (kubecmd.NodeVersion, error) GetConstellationVersion(ctx context.Context) (kubecmd.NodeVersion, error)

View file

@ -7,6 +7,7 @@ SPDX-License-Identifier: AGPL-3.0-only
package cmd package cmd
import ( import (
"bytes"
"context" "context"
"fmt" "fmt"
"testing" "testing"
@ -14,8 +15,12 @@ import (
"github.com/edgelesssys/constellation/v2/cli/internal/kubecmd" "github.com/edgelesssys/constellation/v2/cli/internal/kubecmd"
"github.com/edgelesssys/constellation/v2/internal/attestation/measurements" "github.com/edgelesssys/constellation/v2/internal/attestation/measurements"
"github.com/edgelesssys/constellation/v2/internal/attestation/variant" "github.com/edgelesssys/constellation/v2/internal/attestation/variant"
"github.com/edgelesssys/constellation/v2/internal/cloud/cloudprovider"
"github.com/edgelesssys/constellation/v2/internal/config" "github.com/edgelesssys/constellation/v2/internal/config"
"github.com/edgelesssys/constellation/v2/internal/constants"
"github.com/edgelesssys/constellation/v2/internal/file"
updatev1alpha1 "github.com/edgelesssys/constellation/v2/operators/constellation-node-operator/v2/api/v1alpha1" updatev1alpha1 "github.com/edgelesssys/constellation/v2/operators/constellation-node-operator/v2/api/v1alpha1"
"github.com/spf13/afero"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
corev1 "k8s.io/api/core/v1" corev1 "k8s.io/api/core/v1"
@ -63,7 +68,6 @@ func TestStatus(t *testing.T) {
testCases := map[string]struct { testCases := map[string]struct {
kubeClient stubKubeClient kubeClient stubKubeClient
attestVariant variant.Variant
expectedOutput string expectedOutput string
wantErr bool wantErr bool
}{ }{
@ -104,7 +108,6 @@ func TestStatus(t *testing.T) {
}, },
}, },
}, },
attestVariant: variant.QEMUVTPM{},
expectedOutput: successOutput, expectedOutput: successOutput,
}, },
"one of two nodes not upgraded": { "one of two nodes not upgraded": {
@ -157,7 +160,6 @@ func TestStatus(t *testing.T) {
}, },
}, },
}, },
attestVariant: variant.QEMUVTPM{},
expectedOutput: inProgressOutput, expectedOutput: inProgressOutput,
}, },
"error getting node status": { "error getting node status": {
@ -183,7 +185,6 @@ func TestStatus(t *testing.T) {
}, },
}, },
}, },
attestVariant: variant.QEMUVTPM{},
expectedOutput: successOutput, expectedOutput: successOutput,
wantErr: true, wantErr: true,
}, },
@ -211,7 +212,6 @@ func TestStatus(t *testing.T) {
}, },
}, },
}, },
attestVariant: variant.QEMUVTPM{},
expectedOutput: successOutput, expectedOutput: successOutput,
wantErr: true, wantErr: true,
}, },
@ -248,7 +248,6 @@ func TestStatus(t *testing.T) {
}), }),
attestationErr: assert.AnError, attestationErr: assert.AnError,
}, },
attestVariant: variant.QEMUVTPM{},
expectedOutput: successOutput, expectedOutput: successOutput,
wantErr: true, wantErr: true,
}, },
@ -259,19 +258,31 @@ func TestStatus(t *testing.T) {
require := require.New(t) require := require.New(t)
assert := assert.New(t) assert := assert.New(t)
variant := variant.AWSNitroTPM{} cmd := NewStatusCmd()
output, err := status( var out bytes.Buffer
context.Background(), cmd.SetOut(&out)
var errOut bytes.Buffer
cmd.SetErr(&errOut)
fileHandler := file.NewHandler(afero.NewMemMapFs())
cfg, err := createConfigWithAttestationVariant(cloudprovider.QEMU, "", variant.QEMUVTPM{})
require.NoError(err)
require.NoError(fileHandler.WriteYAML(constants.ConfigFilename, cfg))
s := statusCmd{fileHandler: fileHandler}
err = s.status(
cmd,
stubGetVersions(versionsOutput), stubGetVersions(versionsOutput),
tc.kubeClient, tc.kubeClient,
variant, stubAttestationFetcher{},
) )
if tc.wantErr { if tc.wantErr {
assert.Error(err) assert.Error(err)
return return
} }
require.NoError(err) require.NoError(err)
assert.Equal(tc.expectedOutput, output) assert.Equal(tc.expectedOutput, out.String())
}) })
} }
} }

View file

@ -13,10 +13,9 @@ import (
"github.com/spf13/afero" "github.com/spf13/afero"
"github.com/spf13/cobra" "github.com/spf13/cobra"
"github.com/spf13/pflag"
"github.com/edgelesssys/constellation/v2/cli/internal/cloudcmd" "github.com/edgelesssys/constellation/v2/cli/internal/cloudcmd"
"github.com/edgelesssys/constellation/v2/cli/internal/cmd/pathprefix"
"github.com/edgelesssys/constellation/v2/cli/internal/terraform"
"github.com/edgelesssys/constellation/v2/internal/constants" "github.com/edgelesssys/constellation/v2/internal/constants"
"github.com/edgelesssys/constellation/v2/internal/file" "github.com/edgelesssys/constellation/v2/internal/file"
) )
@ -35,9 +34,26 @@ func NewTerminateCmd() *cobra.Command {
return cmd return cmd
} }
type terminateFlags struct {
rootFlags
yes bool
}
func (f *terminateFlags) parse(flags *pflag.FlagSet) error {
if err := f.rootFlags.parse(flags); err != nil {
return err
}
yes, err := flags.GetBool("yes")
if err != nil {
return fmt.Errorf("getting 'yes' flag: %w", err)
}
f.yes = yes
return nil
}
// runTerminate runs the terminate command. // runTerminate runs the terminate command.
func runTerminate(cmd *cobra.Command, _ []string) error { func runTerminate(cmd *cobra.Command, _ []string) error {
fileHandler := file.NewHandler(afero.NewOsFs())
spinner, err := newSpinnerOrStderr(cmd) spinner, err := newSpinnerOrStderr(cmd)
if err != nil { if err != nil {
return fmt.Errorf("creating spinner: %w", err) return fmt.Errorf("creating spinner: %w", err)
@ -45,18 +61,27 @@ func runTerminate(cmd *cobra.Command, _ []string) error {
defer spinner.Stop() defer spinner.Stop()
terminator := cloudcmd.NewTerminator() terminator := cloudcmd.NewTerminator()
return terminate(cmd, terminator, fileHandler, spinner) logger, err := newCLILogger(cmd)
if err != nil {
return fmt.Errorf("creating logger: %w", err)
}
t := &terminateCmd{log: logger, fileHandler: file.NewHandler(afero.NewOsFs())}
if err := t.flags.parse(cmd.Flags()); err != nil {
return err
}
return t.terminate(cmd, terminator, spinner)
} }
func terminate(cmd *cobra.Command, terminator cloudTerminator, fileHandler file.Handler, spinner spinnerInterf, type terminateCmd struct {
) error { log debugLog
flags, err := parseTerminateFlags(cmd) fileHandler file.Handler
if err != nil { flags terminateFlags
return fmt.Errorf("parsing flags: %w", err) }
}
pf := pathprefix.New(flags.workspace)
if !flags.yes { func (t *terminateCmd) terminate(cmd *cobra.Command, terminator cloudTerminator, spinner spinnerInterf) error {
if !t.flags.yes {
cmd.Println("You are about to terminate a Constellation cluster.") cmd.Println("You are about to terminate a Constellation cluster.")
cmd.Println("All of its associated resources will be DESTROYED.") cmd.Println("All of its associated resources will be DESTROYED.")
cmd.Println("This action is irreversible and ALL DATA WILL BE LOST.") cmd.Println("This action is irreversible and ALL DATA WILL BE LOST.")
@ -71,7 +96,7 @@ func terminate(cmd *cobra.Command, terminator cloudTerminator, fileHandler file.
} }
spinner.Start("Terminating", false) spinner.Start("Terminating", false)
err = terminator.Terminate(cmd.Context(), constants.TerraformWorkingDir, flags.logLevel) err := terminator.Terminate(cmd.Context(), constants.TerraformWorkingDir, t.flags.tfLogLevel)
spinner.Stop() spinner.Stop()
if err != nil { if err != nil {
return fmt.Errorf("terminating Constellation cluster: %w", err) return fmt.Errorf("terminating Constellation cluster: %w", err)
@ -80,44 +105,13 @@ func terminate(cmd *cobra.Command, terminator cloudTerminator, fileHandler file.
cmd.Println("Your Constellation cluster was terminated successfully.") cmd.Println("Your Constellation cluster was terminated successfully.")
var removeErr error var removeErr error
if err := fileHandler.Remove(constants.AdminConfFilename); err != nil && !errors.Is(err, fs.ErrNotExist) { if err := t.fileHandler.Remove(constants.AdminConfFilename); err != nil && !errors.Is(err, fs.ErrNotExist) {
removeErr = errors.Join(err, fmt.Errorf("failed to remove file: '%s', please remove it manually", pf.PrefixPrintablePath(constants.AdminConfFilename))) removeErr = errors.Join(err, fmt.Errorf("failed to remove file: '%s', please remove it manually", t.flags.pathPrefixer.PrefixPrintablePath(constants.AdminConfFilename)))
} }
if err := fileHandler.Remove(constants.StateFilename); err != nil && !errors.Is(err, fs.ErrNotExist) { if err := t.fileHandler.Remove(constants.StateFilename); err != nil && !errors.Is(err, fs.ErrNotExist) {
removeErr = errors.Join(err, fmt.Errorf("failed to remove file: '%s', please remove it manually", pf.PrefixPrintablePath(constants.StateFilename))) removeErr = errors.Join(err, fmt.Errorf("failed to remove file: '%s', please remove it manually", t.flags.pathPrefixer.PrefixPrintablePath(constants.StateFilename)))
} }
return removeErr return removeErr
} }
type terminateFlags struct {
yes bool
workspace string
logLevel terraform.LogLevel
}
func parseTerminateFlags(cmd *cobra.Command) (terminateFlags, error) {
yes, err := cmd.Flags().GetBool("yes")
if err != nil {
return terminateFlags{}, fmt.Errorf("parsing yes bool: %w", err)
}
logLevelString, err := cmd.Flags().GetString("tf-log")
if err != nil {
return terminateFlags{}, fmt.Errorf("parsing tf-log string: %w", err)
}
logLevel, err := terraform.ParseLogLevel(logLevelString)
if err != nil {
return terminateFlags{}, fmt.Errorf("parsing Terraform log level %s: %w", logLevelString, err)
}
workspace, err := cmd.Flags().GetString("workspace")
if err != nil {
return terminateFlags{}, fmt.Errorf("parsing workspace string: %w", err)
}
return terminateFlags{
yes: yes,
workspace: workspace,
logLevel: logLevel,
}, nil
}

View file

@ -14,6 +14,7 @@ import (
"github.com/edgelesssys/constellation/v2/cli/internal/state" "github.com/edgelesssys/constellation/v2/cli/internal/state"
"github.com/edgelesssys/constellation/v2/internal/constants" "github.com/edgelesssys/constellation/v2/internal/constants"
"github.com/edgelesssys/constellation/v2/internal/file" "github.com/edgelesssys/constellation/v2/internal/file"
"github.com/edgelesssys/constellation/v2/internal/logger"
"github.com/spf13/afero" "github.com/spf13/afero"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
@ -134,18 +135,17 @@ func TestTerminate(t *testing.T) {
cmd.SetErr(&bytes.Buffer{}) cmd.SetErr(&bytes.Buffer{})
cmd.SetIn(bytes.NewBufferString(tc.stdin)) cmd.SetIn(bytes.NewBufferString(tc.stdin))
// register persistent flags manually
cmd.Flags().String("tf-log", "NONE", "")
cmd.Flags().String("workspace", "", "")
require.NotNil(tc.setupFs) require.NotNil(tc.setupFs)
fileHandler := file.NewHandler(tc.setupFs(require, tc.stateFile)) fileHandler := file.NewHandler(tc.setupFs(require, tc.stateFile))
if tc.yesFlag { tCmd := &terminateCmd{
require.NoError(cmd.Flags().Set("yes", "true")) log: logger.NewTest(t),
fileHandler: fileHandler,
flags: terminateFlags{
yes: tc.yesFlag,
},
} }
err := tCmd.terminate(cmd, tc.terminator, &nopSpinner{})
err := terminate(cmd, tc.terminator, fileHandler, &nopSpinner{})
if tc.wantErr { if tc.wantErr {
assert.Error(err) assert.Error(err)

View file

@ -16,7 +16,6 @@ import (
"time" "time"
"github.com/edgelesssys/constellation/v2/cli/internal/cloudcmd" "github.com/edgelesssys/constellation/v2/cli/internal/cloudcmd"
"github.com/edgelesssys/constellation/v2/cli/internal/cmd/pathprefix"
"github.com/edgelesssys/constellation/v2/cli/internal/helm" "github.com/edgelesssys/constellation/v2/cli/internal/helm"
"github.com/edgelesssys/constellation/v2/cli/internal/kubecmd" "github.com/edgelesssys/constellation/v2/cli/internal/kubecmd"
"github.com/edgelesssys/constellation/v2/cli/internal/state" "github.com/edgelesssys/constellation/v2/cli/internal/state"
@ -33,6 +32,7 @@ import (
"github.com/rogpeppe/go-internal/diff" "github.com/rogpeppe/go-internal/diff"
"github.com/spf13/afero" "github.com/spf13/afero"
"github.com/spf13/cobra" "github.com/spf13/cobra"
"github.com/spf13/pflag"
"gopkg.in/yaml.v3" "gopkg.in/yaml.v3"
apiextensionsv1 "k8s.io/apiextensions-apiserver/pkg/apis/apiextensions/v1" apiextensionsv1 "k8s.io/apiextensions-apiserver/pkg/apis/apiextensions/v1"
) )
@ -76,12 +76,62 @@ func newUpgradeApplyCmd() *cobra.Command {
return cmd return cmd
} }
func runUpgradeApply(cmd *cobra.Command, _ []string) error { type upgradeApplyFlags struct {
flags, err := parseUpgradeApplyFlags(cmd) rootFlags
if err != nil { yes bool
return fmt.Errorf("parsing flags: %w", err) upgradeTimeout time.Duration
conformance bool
helmWaitMode helm.WaitMode
skipPhases skipPhases
}
func (f *upgradeApplyFlags) parse(flags *pflag.FlagSet) error {
if err := f.rootFlags.parse(flags); err != nil {
return err
} }
rawSkipPhases, err := flags.GetStringSlice("skip-phases")
if err != nil {
return fmt.Errorf("parsing skip-phases flag: %w", err)
}
var skipPhases []skipPhase
for _, phase := range rawSkipPhases {
switch skipPhase(phase) {
case skipInfrastructurePhase, skipHelmPhase, skipImagePhase, skipK8sPhase:
skipPhases = append(skipPhases, skipPhase(phase))
default:
return fmt.Errorf("invalid phase %s", phase)
}
}
f.skipPhases = skipPhases
f.yes, err = flags.GetBool("yes")
if err != nil {
return fmt.Errorf("getting 'yes' flag: %w", err)
}
f.upgradeTimeout, err = flags.GetDuration("timeout")
if err != nil {
return fmt.Errorf("getting 'timeout' flag: %w", err)
}
f.conformance, err = flags.GetBool("conformance")
if err != nil {
return fmt.Errorf("getting 'conformance' flag: %w", err)
}
skipHelmWait, err := flags.GetBool("skip-helm-wait")
if err != nil {
return fmt.Errorf("getting 'skip-helm-wait' flag: %w", err)
}
f.helmWaitMode = helm.WaitModeAtomic
if skipHelmWait {
f.helmWaitMode = helm.WaitModeNone
}
return nil
}
func runUpgradeApply(cmd *cobra.Command, _ []string) error {
log, err := newCLILogger(cmd) log, err := newCLILogger(cmd)
if err != nil { if err != nil {
return fmt.Errorf("creating logger: %w", err) return fmt.Errorf("creating logger: %w", err)
@ -98,13 +148,18 @@ func runUpgradeApply(cmd *cobra.Command, _ []string) error {
configFetcher := attestationconfigapi.NewFetcher() configFetcher := attestationconfigapi.NewFetcher()
var flags upgradeApplyFlags
if err := flags.parse(cmd.Flags()); err != nil {
return err
}
// Set up terraform upgrader // Set up terraform upgrader
upgradeDir := filepath.Join(constants.UpgradeDir, upgradeID) upgradeDir := filepath.Join(constants.UpgradeDir, upgradeID)
clusterUpgrader, err := cloudcmd.NewClusterUpgrader( clusterUpgrader, err := cloudcmd.NewClusterUpgrader(
cmd.Context(), cmd.Context(),
constants.TerraformWorkingDir, constants.TerraformWorkingDir,
upgradeDir, upgradeDir,
flags.terraformLogLevel, flags.tfLogLevel,
fileHandler, fileHandler,
) )
if err != nil { if err != nil {
@ -122,9 +177,10 @@ func runUpgradeApply(cmd *cobra.Command, _ []string) error {
clusterUpgrader: clusterUpgrader, clusterUpgrader: clusterUpgrader,
configFetcher: configFetcher, configFetcher: configFetcher,
fileHandler: fileHandler, fileHandler: fileHandler,
flags: flags,
log: log, log: log,
} }
return applyCmd.upgradeApply(cmd, upgradeDir, flags) return applyCmd.upgradeApply(cmd, upgradeDir)
} }
type upgradeApplyCmd struct { type upgradeApplyCmd struct {
@ -133,11 +189,12 @@ type upgradeApplyCmd struct {
clusterUpgrader clusterUpgrader clusterUpgrader clusterUpgrader
configFetcher attestationconfigapi.Fetcher configFetcher attestationconfigapi.Fetcher
fileHandler file.Handler fileHandler file.Handler
flags upgradeApplyFlags
log debugLog log debugLog
} }
func (u *upgradeApplyCmd) upgradeApply(cmd *cobra.Command, upgradeDir string, flags upgradeApplyFlags) error { func (u *upgradeApplyCmd) upgradeApply(cmd *cobra.Command, upgradeDir string) error {
conf, err := config.New(u.fileHandler, constants.ConfigFilename, u.configFetcher, flags.force) conf, err := config.New(u.fileHandler, constants.ConfigFilename, u.configFetcher, u.flags.force)
var configValidationErr *config.ValidationError var configValidationErr *config.ValidationError
if errors.As(err, &configValidationErr) { if errors.As(err, &configValidationErr) {
cmd.PrintErrln(configValidationErr.LongMessage()) cmd.PrintErrln(configValidationErr.LongMessage())
@ -147,7 +204,7 @@ func (u *upgradeApplyCmd) upgradeApply(cmd *cobra.Command, upgradeDir string, fl
} }
if cloudcmd.UpgradeRequiresIAMMigration(conf.GetProvider()) { if cloudcmd.UpgradeRequiresIAMMigration(conf.GetProvider()) {
cmd.Println("WARNING: This upgrade requires an IAM migration. Please make sure you have applied the IAM migration using `iam upgrade apply` before continuing.") cmd.Println("WARNING: This upgrade requires an IAM migration. Please make sure you have applied the IAM migration using `iam upgrade apply` before continuing.")
if !flags.yes { if !u.flags.yes {
yes, err := askToConfirm(cmd, "Did you upgrade the IAM resources?") yes, err := askToConfirm(cmd, "Did you upgrade the IAM resources?")
if err != nil { if err != nil {
return fmt.Errorf("asking for confirmation: %w", err) return fmt.Errorf("asking for confirmation: %w", err)
@ -158,7 +215,7 @@ func (u *upgradeApplyCmd) upgradeApply(cmd *cobra.Command, upgradeDir string, fl
} }
} }
} }
conf.KubernetesVersion, err = validK8sVersion(cmd, string(conf.KubernetesVersion), flags.yes) conf.KubernetesVersion, err = validK8sVersion(cmd, string(conf.KubernetesVersion), u.flags.yes)
if err != nil { if err != nil {
return err return err
} }
@ -168,21 +225,21 @@ func (u *upgradeApplyCmd) upgradeApply(cmd *cobra.Command, upgradeDir string, fl
return fmt.Errorf("reading state file: %w", err) return fmt.Errorf("reading state file: %w", err)
} }
if err := u.confirmAndUpgradeAttestationConfig(cmd, conf.GetAttestationConfig(), stateFile.ClusterValues.MeasurementSalt, flags); err != nil { if err := u.confirmAndUpgradeAttestationConfig(cmd, conf.GetAttestationConfig(), stateFile.ClusterValues.MeasurementSalt); err != nil {
return fmt.Errorf("upgrading measurements: %w", err) return fmt.Errorf("upgrading measurements: %w", err)
} }
// If infrastructure phase is skipped, we expect the new infrastructure // If infrastructure phase is skipped, we expect the new infrastructure
// to be in the Terraform configuration already. Otherwise, perform // to be in the Terraform configuration already. Otherwise, perform
// the Terraform migrations. // the Terraform migrations.
if !flags.skipPhases.contains(skipInfrastructurePhase) { if !u.flags.skipPhases.contains(skipInfrastructurePhase) {
migrationRequired, err := u.planTerraformMigration(cmd, conf) migrationRequired, err := u.planTerraformMigration(cmd, conf)
if err != nil { if err != nil {
return fmt.Errorf("planning Terraform migrations: %w", err) return fmt.Errorf("planning Terraform migrations: %w", err)
} }
if migrationRequired { if migrationRequired {
postMigrationInfraState, err := u.migrateTerraform(cmd, conf, upgradeDir, flags) postMigrationInfraState, err := u.migrateTerraform(cmd, conf, upgradeDir)
if err != nil { if err != nil {
return fmt.Errorf("performing Terraform migrations: %w", err) return fmt.Errorf("performing Terraform migrations: %w", err)
} }
@ -217,8 +274,8 @@ func (u *upgradeApplyCmd) upgradeApply(cmd *cobra.Command, upgradeDir string, fl
} }
var upgradeErr *compatibility.InvalidUpgradeError var upgradeErr *compatibility.InvalidUpgradeError
if !flags.skipPhases.contains(skipHelmPhase) { if !u.flags.skipPhases.contains(skipHelmPhase) {
err = u.handleServiceUpgrade(cmd, conf, stateFile, upgradeDir, flags) err = u.handleServiceUpgrade(cmd, conf, stateFile, upgradeDir)
switch { switch {
case errors.As(err, &upgradeErr): case errors.As(err, &upgradeErr):
cmd.PrintErrln(err) cmd.PrintErrln(err)
@ -228,10 +285,10 @@ func (u *upgradeApplyCmd) upgradeApply(cmd *cobra.Command, upgradeDir string, fl
return fmt.Errorf("upgrading services: %w", err) return fmt.Errorf("upgrading services: %w", err)
} }
} }
skipImageUpgrade := flags.skipPhases.contains(skipImagePhase) skipImageUpgrade := u.flags.skipPhases.contains(skipImagePhase)
skipK8sUpgrade := flags.skipPhases.contains(skipK8sPhase) skipK8sUpgrade := u.flags.skipPhases.contains(skipK8sPhase)
if !(skipImageUpgrade && skipK8sUpgrade) { if !(skipImageUpgrade && skipK8sUpgrade) {
err = u.kubeUpgrader.UpgradeNodeVersion(cmd.Context(), conf, flags.force, skipImageUpgrade, skipK8sUpgrade) err = u.kubeUpgrader.UpgradeNodeVersion(cmd.Context(), conf, u.flags.force, skipImageUpgrade, skipK8sUpgrade)
switch { switch {
case errors.Is(err, kubecmd.ErrInProgress): case errors.Is(err, kubecmd.ErrInProgress):
cmd.PrintErrln("Skipping image and Kubernetes upgrades. Another upgrade is in progress.") cmd.PrintErrln("Skipping image and Kubernetes upgrades. Another upgrade is in progress.")
@ -284,12 +341,11 @@ func (u *upgradeApplyCmd) planTerraformMigration(cmd *cobra.Command, conf *confi
// migrateTerraform checks if the Constellation version the cluster is being upgraded to requires a migration // migrateTerraform checks if the Constellation version the cluster is being upgraded to requires a migration
// of cloud resources with Terraform. If so, the migration is performed and the post-migration infrastructure state is returned. // of cloud resources with Terraform. If so, the migration is performed and the post-migration infrastructure state is returned.
// If no migration is required, the current (pre-upgrade) infrastructure state is returned. // If no migration is required, the current (pre-upgrade) infrastructure state is returned.
func (u *upgradeApplyCmd) migrateTerraform( func (u *upgradeApplyCmd) migrateTerraform(cmd *cobra.Command, conf *config.Config, upgradeDir string,
cmd *cobra.Command, conf *config.Config, upgradeDir string, flags upgradeApplyFlags,
) (state.Infrastructure, error) { ) (state.Infrastructure, error) {
// If there are any Terraform migrations to apply, ask for confirmation // If there are any Terraform migrations to apply, ask for confirmation
fmt.Fprintln(cmd.OutOrStdout(), "The upgrade requires a migration of Constellation cloud resources by applying an updated Terraform template. Please manually review the suggested changes below.") fmt.Fprintln(cmd.OutOrStdout(), "The upgrade requires a migration of Constellation cloud resources by applying an updated Terraform template. Please manually review the suggested changes below.")
if !flags.yes { if !u.flags.yes {
ok, err := askToConfirm(cmd, "Do you want to apply the Terraform migrations?") ok, err := askToConfirm(cmd, "Do you want to apply the Terraform migrations?")
if err != nil { if err != nil {
return state.Infrastructure{}, fmt.Errorf("asking for confirmation: %w", err) return state.Infrastructure{}, fmt.Errorf("asking for confirmation: %w", err)
@ -317,8 +373,8 @@ func (u *upgradeApplyCmd) migrateTerraform(
cmd.Printf("Infrastructure migrations applied successfully and output written to: %s\n"+ cmd.Printf("Infrastructure migrations applied successfully and output written to: %s\n"+
"A backup of the pre-upgrade state has been written to: %s\n", "A backup of the pre-upgrade state has been written to: %s\n",
flags.pf.PrefixPrintablePath(constants.StateFilename), u.flags.pathPrefixer.PrefixPrintablePath(constants.StateFilename),
flags.pf.PrefixPrintablePath(filepath.Join(upgradeDir, constants.TerraformUpgradeBackupDir)), u.flags.pathPrefixer.PrefixPrintablePath(filepath.Join(upgradeDir, constants.TerraformUpgradeBackupDir)),
) )
return infraState, nil return infraState, nil
} }
@ -347,7 +403,7 @@ func validK8sVersion(cmd *cobra.Command, version string, yes bool) (validVersion
// confirmAndUpgradeAttestationConfig checks if the locally configured measurements are different from the cluster's measurements. // confirmAndUpgradeAttestationConfig checks if the locally configured measurements are different from the cluster's measurements.
// If so the function will ask the user to confirm (if --yes is not set) and upgrade the cluster's config. // If so the function will ask the user to confirm (if --yes is not set) and upgrade the cluster's config.
func (u *upgradeApplyCmd) confirmAndUpgradeAttestationConfig( func (u *upgradeApplyCmd) confirmAndUpgradeAttestationConfig(
cmd *cobra.Command, newConfig config.AttestationCfg, measurementSalt []byte, flags upgradeApplyFlags, cmd *cobra.Command, newConfig config.AttestationCfg, measurementSalt []byte,
) error { ) error {
clusterAttestationConfig, err := u.kubeUpgrader.GetClusterAttestationConfig(cmd.Context(), newConfig.GetVariant()) clusterAttestationConfig, err := u.kubeUpgrader.GetClusterAttestationConfig(cmd.Context(), newConfig.GetVariant())
if err != nil { if err != nil {
@ -369,7 +425,7 @@ func (u *upgradeApplyCmd) confirmAndUpgradeAttestationConfig(
} }
cmd.Println("The following changes will be applied to the attestation config:") cmd.Println("The following changes will be applied to the attestation config:")
cmd.Println(diffStr) cmd.Println(diffStr)
if !flags.yes { if !u.flags.yes {
ok, err := askToConfirm(cmd, "Are you sure you want to change your cluster's attestation config?") ok, err := askToConfirm(cmd, "Are you sure you want to change your cluster's attestation config?")
if err != nil { if err != nil {
return fmt.Errorf("asking for confirmation: %w", err) return fmt.Errorf("asking for confirmation: %w", err)
@ -387,21 +443,20 @@ func (u *upgradeApplyCmd) confirmAndUpgradeAttestationConfig(
} }
func (u *upgradeApplyCmd) handleServiceUpgrade( func (u *upgradeApplyCmd) handleServiceUpgrade(
cmd *cobra.Command, conf *config.Config, stateFile *state.State, cmd *cobra.Command, conf *config.Config, stateFile *state.State, upgradeDir string,
upgradeDir string, flags upgradeApplyFlags,
) error { ) error {
var secret uri.MasterSecret var secret uri.MasterSecret
if err := u.fileHandler.ReadJSON(constants.MasterSecretFilename, &secret); err != nil { if err := u.fileHandler.ReadJSON(constants.MasterSecretFilename, &secret); err != nil {
return fmt.Errorf("reading master secret: %w", err) return fmt.Errorf("reading master secret: %w", err)
} }
serviceAccURI, err := cloudcmd.GetMarshaledServiceAccountURI(conf.GetProvider(), conf, flags.pf, u.log, u.fileHandler) serviceAccURI, err := cloudcmd.GetMarshaledServiceAccountURI(conf.GetProvider(), conf, u.flags.pathPrefixer, u.log, u.fileHandler)
if err != nil { if err != nil {
return fmt.Errorf("getting service account URI: %w", err) return fmt.Errorf("getting service account URI: %w", err)
} }
options := helm.Options{ options := helm.Options{
Force: flags.force, Force: u.flags.force,
Conformance: flags.conformance, Conformance: u.flags.conformance,
HelmWaitMode: flags.helmWaitMode, HelmWaitMode: u.flags.helmWaitMode,
} }
prepareApply := func(allowDestructive bool) (helm.Applier, bool, error) { prepareApply := func(allowDestructive bool) (helm.Applier, bool, error) {
@ -422,7 +477,7 @@ func (u *upgradeApplyCmd) handleServiceUpgrade(
if !errors.Is(err, helm.ErrConfirmationMissing) { if !errors.Is(err, helm.ErrConfirmationMissing) {
return fmt.Errorf("upgrading charts with deny destructive mode: %w", err) return fmt.Errorf("upgrading charts with deny destructive mode: %w", err)
} }
if !flags.yes { if !u.flags.yes {
cmd.PrintErrln("WARNING: Upgrading cert-manager will destroy all custom resources you have manually created that are based on the current version of cert-manager.") cmd.PrintErrln("WARNING: Upgrading cert-manager will destroy all custom resources you have manually created that are based on the current version of cert-manager.")
ok, askErr := askToConfirm(cmd, "Do you want to upgrade cert-manager anyway?") ok, askErr := askToConfirm(cmd, "Do you want to upgrade cert-manager anyway?")
if askErr != nil { if askErr != nil {
@ -463,86 +518,6 @@ func (u *upgradeApplyCmd) handleServiceUpgrade(
return nil return nil
} }
func parseUpgradeApplyFlags(cmd *cobra.Command) (upgradeApplyFlags, error) {
workDir, err := cmd.Flags().GetString("workspace")
if err != nil {
return upgradeApplyFlags{}, err
}
yes, err := cmd.Flags().GetBool("yes")
if err != nil {
return upgradeApplyFlags{}, err
}
timeout, err := cmd.Flags().GetDuration("timeout")
if err != nil {
return upgradeApplyFlags{}, err
}
force, err := cmd.Flags().GetBool("force")
if err != nil {
return upgradeApplyFlags{}, fmt.Errorf("parsing force argument: %w", err)
}
logLevelString, err := cmd.Flags().GetString("tf-log")
if err != nil {
return upgradeApplyFlags{}, fmt.Errorf("parsing tf-log string: %w", err)
}
logLevel, err := terraform.ParseLogLevel(logLevelString)
if err != nil {
return upgradeApplyFlags{}, fmt.Errorf("parsing Terraform log level %s: %w", logLevelString, err)
}
conformance, err := cmd.Flags().GetBool("conformance")
if err != nil {
return upgradeApplyFlags{}, fmt.Errorf("parsing conformance flag: %w", err)
}
skipHelmWait, err := cmd.Flags().GetBool("skip-helm-wait")
if err != nil {
return upgradeApplyFlags{}, fmt.Errorf("parsing skip-helm-wait flag: %w", err)
}
helmWaitMode := helm.WaitModeAtomic
if skipHelmWait {
helmWaitMode = helm.WaitModeNone
}
rawSkipPhases, err := cmd.Flags().GetStringSlice("skip-phases")
if err != nil {
return upgradeApplyFlags{}, fmt.Errorf("parsing skip-phases flag: %w", err)
}
var skipPhases []skipPhase
for _, phase := range rawSkipPhases {
switch skipPhase(phase) {
case skipInfrastructurePhase, skipHelmPhase, skipImagePhase, skipK8sPhase:
skipPhases = append(skipPhases, skipPhase(phase))
default:
return upgradeApplyFlags{}, fmt.Errorf("invalid phase %s", phase)
}
}
return upgradeApplyFlags{
pf: pathprefix.New(workDir),
yes: yes,
upgradeTimeout: timeout,
force: force,
terraformLogLevel: logLevel,
conformance: conformance,
helmWaitMode: helmWaitMode,
skipPhases: skipPhases,
}, nil
}
type upgradeApplyFlags struct {
pf pathprefix.PathPrefixer
yes bool
upgradeTimeout time.Duration
force bool
terraformLogLevel terraform.LogLevel
conformance bool
helmWaitMode helm.WaitMode
skipPhases skipPhases
}
// skipPhases is a list of phases that can be skipped during the upgrade process. // skipPhases is a list of phases that can be skipped during the upgrade process.
type skipPhases []skipPhase type skipPhases []skipPhase

View file

@ -227,10 +227,11 @@ func TestUpgradeApply(t *testing.T) {
clusterUpgrader: tc.terraformUpgrader, clusterUpgrader: tc.terraformUpgrader,
log: logger.NewTest(t), log: logger.NewTest(t),
configFetcher: stubAttestationFetcher{}, configFetcher: stubAttestationFetcher{},
flags: tc.flags,
fileHandler: fh, fileHandler: fh,
} }
err := upgrader.upgradeApply(cmd, "test", tc.flags) err := upgrader.upgradeApply(cmd, "test")
if tc.wantErr { if tc.wantErr {
assert.Error(err) assert.Error(err)
return return
@ -247,16 +248,20 @@ func TestUpgradeApply(t *testing.T) {
} }
func TestUpgradeApplyFlagsForSkipPhases(t *testing.T) { func TestUpgradeApplyFlagsForSkipPhases(t *testing.T) {
require := require.New(t)
cmd := newUpgradeApplyCmd() cmd := newUpgradeApplyCmd()
cmd.Flags().String("workspace", "", "") // register persistent flag manually // register persistent flags manually
cmd.Flags().Bool("force", true, "") // register persistent flag manually cmd.Flags().String("workspace", "", "")
cmd.Flags().String("tf-log", "NONE", "") // register persistent flag manually cmd.Flags().Bool("force", true, "")
require.NoError(t, cmd.Flags().Set("skip-phases", "infrastructure,helm,k8s,image")) cmd.Flags().String("tf-log", "NONE", "")
result, err := parseUpgradeApplyFlags(cmd) cmd.Flags().Bool("debug", false, "")
if err != nil {
t.Fatalf("Error while parsing flags: %v", err) require.NoError(cmd.Flags().Set("skip-phases", "infrastructure,helm,k8s,image"))
}
assert.ElementsMatch(t, []skipPhase{skipInfrastructurePhase, skipHelmPhase, skipK8sPhase, skipImagePhase}, result.skipPhases) var flags upgradeApplyFlags
err := flags.parse(cmd.Flags())
require.NoError(err)
assert.ElementsMatch(t, []skipPhase{skipInfrastructurePhase, skipHelmPhase, skipK8sPhase, skipImagePhase}, flags.skipPhases)
} }
type stubKubernetesUpgrader struct { type stubKubernetesUpgrader struct {

View file

@ -38,6 +38,7 @@ import (
"github.com/siderolabs/talos/pkg/machinery/config/encoder" "github.com/siderolabs/talos/pkg/machinery/config/encoder"
"github.com/spf13/afero" "github.com/spf13/afero"
"github.com/spf13/cobra" "github.com/spf13/cobra"
"github.com/spf13/pflag"
"golang.org/x/mod/semver" "golang.org/x/mod/semver"
) )
@ -57,6 +58,36 @@ func newUpgradeCheckCmd() *cobra.Command {
return cmd return cmd
} }
type upgradeCheckFlags struct {
rootFlags
updateConfig bool
ref string
stream string
}
func (f *upgradeCheckFlags) parse(flags *pflag.FlagSet) error {
if err := f.rootFlags.parse(flags); err != nil {
return err
}
updateConfig, err := flags.GetBool("update-config")
if err != nil {
return fmt.Errorf("getting 'update-config' flag: %w", err)
}
f.updateConfig = updateConfig
f.ref, err = flags.GetString("ref")
if err != nil {
return fmt.Errorf("getting 'ref' flag: %w", err)
}
f.stream, err = flags.GetString("stream")
if err != nil {
return fmt.Errorf("getting 'stream' flag: %w", err)
}
return nil
}
func runUpgradeCheck(cmd *cobra.Command, _ []string) error { func runUpgradeCheck(cmd *cobra.Command, _ []string) error {
log, err := newCLILogger(cmd) log, err := newCLILogger(cmd)
if err != nil { if err != nil {
@ -64,8 +95,8 @@ func runUpgradeCheck(cmd *cobra.Command, _ []string) error {
} }
defer log.Sync() defer log.Sync()
flags, err := parseUpgradeCheckFlags(cmd) var flags upgradeCheckFlags
if err != nil { if err := flags.parse(cmd.Flags()); err != nil {
return err return err
} }
@ -77,7 +108,7 @@ func runUpgradeCheck(cmd *cobra.Command, _ []string) error {
cmd.Context(), cmd.Context(),
constants.TerraformWorkingDir, constants.TerraformWorkingDir,
upgradeDir, upgradeDir,
flags.terraformLogLevel, flags.tfLogLevel,
fileHandler, fileHandler,
) )
if err != nil { if err != nil {
@ -111,46 +142,11 @@ func runUpgradeCheck(cmd *cobra.Command, _ []string) error {
upgradeDir: upgradeDir, upgradeDir: upgradeDir,
terraformChecker: tfClient, terraformChecker: tfClient,
fileHandler: fileHandler, fileHandler: fileHandler,
flags: flags,
log: log, log: log,
} }
return up.upgradeCheck(cmd, attestationconfigapi.NewFetcher(), flags) return up.upgradeCheck(cmd, attestationconfigapi.NewFetcher())
}
func parseUpgradeCheckFlags(cmd *cobra.Command) (upgradeCheckFlags, error) {
force, err := cmd.Flags().GetBool("force")
if err != nil {
return upgradeCheckFlags{}, fmt.Errorf("parsing force bool: %w", err)
}
updateConfig, err := cmd.Flags().GetBool("update-config")
if err != nil {
return upgradeCheckFlags{}, fmt.Errorf("parsing update-config bool: %w", err)
}
ref, err := cmd.Flags().GetString("ref")
if err != nil {
return upgradeCheckFlags{}, fmt.Errorf("parsing ref string: %w", err)
}
stream, err := cmd.Flags().GetString("stream")
if err != nil {
return upgradeCheckFlags{}, fmt.Errorf("parsing stream string: %w", err)
}
logLevelString, err := cmd.Flags().GetString("tf-log")
if err != nil {
return upgradeCheckFlags{}, fmt.Errorf("parsing tf-log string: %w", err)
}
logLevel, err := terraform.ParseLogLevel(logLevelString)
if err != nil {
return upgradeCheckFlags{}, fmt.Errorf("parsing Terraform log level %s: %w", logLevelString, err)
}
return upgradeCheckFlags{
force: force,
updateConfig: updateConfig,
ref: ref,
stream: stream,
terraformLogLevel: logLevel,
}, nil
} }
type upgradeCheckCmd struct { type upgradeCheckCmd struct {
@ -159,12 +155,13 @@ type upgradeCheckCmd struct {
collect collector collect collector
terraformChecker terraformChecker terraformChecker terraformChecker
fileHandler file.Handler fileHandler file.Handler
flags upgradeCheckFlags
log debugLog log debugLog
} }
// upgradePlan plans an upgrade of a Constellation cluster. // upgradePlan plans an upgrade of a Constellation cluster.
func (u *upgradeCheckCmd) upgradeCheck(cmd *cobra.Command, fetcher attestationconfigapi.Fetcher, flags upgradeCheckFlags) error { func (u *upgradeCheckCmd) upgradeCheck(cmd *cobra.Command, fetcher attestationconfigapi.Fetcher) error {
conf, err := config.New(u.fileHandler, constants.ConfigFilename, fetcher, flags.force) conf, err := config.New(u.fileHandler, constants.ConfigFilename, fetcher, u.flags.force)
var configValidationErr *config.ValidationError var configValidationErr *config.ValidationError
if errors.As(err, &configValidationErr) { if errors.As(err, &configValidationErr) {
cmd.PrintErrln(configValidationErr.LongMessage()) cmd.PrintErrln(configValidationErr.LongMessage())
@ -271,7 +268,7 @@ func (u *upgradeCheckCmd) upgradeCheck(cmd *cobra.Command, fetcher attestationco
// Using Print over Println as buildString already includes a trailing newline where necessary. // Using Print over Println as buildString already includes a trailing newline where necessary.
cmd.Print(updateMsg) cmd.Print(updateMsg)
if flags.updateConfig { if u.flags.updateConfig {
if err := upgrade.writeConfig(conf, u.fileHandler, constants.ConfigFilename); err != nil { if err := upgrade.writeConfig(conf, u.fileHandler, constants.ConfigFilename); err != nil {
return fmt.Errorf("writing config: %w", err) return fmt.Errorf("writing config: %w", err)
} }
@ -725,14 +722,6 @@ func (v *versionCollector) filterCompatibleCLIVersions(ctx context.Context, cliP
return compatibleVersions, nil return compatibleVersions, nil
} }
type upgradeCheckFlags struct {
force bool
updateConfig bool
ref string
stream string
terraformLogLevel terraform.LogLevel
}
type kubernetesChecker interface { type kubernetesChecker interface {
GetConstellationVersion(ctx context.Context) (kubecmd.NodeVersion, error) GetConstellationVersion(ctx context.Context) (kubecmd.NodeVersion, error)
} }

View file

@ -221,7 +221,7 @@ func TestUpgradeCheck(t *testing.T) {
cmd := newUpgradeCheckCmd() cmd := newUpgradeCheckCmd()
err := checkCmd.upgradeCheck(cmd, stubAttestationFetcher{}, upgradeCheckFlags{}) err := checkCmd.upgradeCheck(cmd, stubAttestationFetcher{})
if tc.wantError { if tc.wantError {
assert.Error(err) assert.Error(err)
return return

View file

@ -26,7 +26,6 @@ import (
tpmProto "github.com/google/go-tpm-tools/proto/tpm" tpmProto "github.com/google/go-tpm-tools/proto/tpm"
"github.com/edgelesssys/constellation/v2/cli/internal/cloudcmd" "github.com/edgelesssys/constellation/v2/cli/internal/cloudcmd"
"github.com/edgelesssys/constellation/v2/cli/internal/cmd/pathprefix"
"github.com/edgelesssys/constellation/v2/cli/internal/state" "github.com/edgelesssys/constellation/v2/cli/internal/state"
"github.com/edgelesssys/constellation/v2/internal/api/attestationconfigapi" "github.com/edgelesssys/constellation/v2/internal/api/attestationconfigapi"
"github.com/edgelesssys/constellation/v2/internal/atls" "github.com/edgelesssys/constellation/v2/internal/atls"
@ -45,6 +44,7 @@ import (
"github.com/google/go-sev-guest/kds" "github.com/google/go-sev-guest/kds"
"github.com/spf13/afero" "github.com/spf13/afero"
"github.com/spf13/cobra" "github.com/spf13/cobra"
"github.com/spf13/pflag"
"google.golang.org/grpc" "google.golang.org/grpc"
) )
@ -64,7 +64,38 @@ func NewVerifyCmd() *cobra.Command {
return cmd return cmd
} }
type verifyFlags struct {
rootFlags
endpoint string
ownerID string
clusterID string
output string
}
func (f *verifyFlags) parse(flags *pflag.FlagSet) error {
if err := f.rootFlags.parse(flags); err != nil {
return err
}
var err error
f.output, err = flags.GetString("output")
if err != nil {
return fmt.Errorf("getting 'output' flag: %w", err)
}
f.endpoint, err = flags.GetString("node-endpoint")
if err != nil {
return fmt.Errorf("getting 'node-endpoint' flag: %w", err)
}
f.clusterID, err = flags.GetString("cluster-id")
if err != nil {
return fmt.Errorf("getting 'cluster-id' flag: %w", err)
}
return nil
}
type verifyCmd struct { type verifyCmd struct {
fileHandler file.Handler
flags verifyFlags
log debugLog log debugLog
} }
@ -95,22 +126,23 @@ func runVerify(cmd *cobra.Command, _ []string) error {
return nil, fmt.Errorf("invalid output value for formatter: %s", output) return nil, fmt.Errorf("invalid output value for formatter: %s", output)
} }
} }
v := &verifyCmd{log: log} v := &verifyCmd{
fileHandler: fileHandler,
log: log,
}
if err := v.flags.parse(cmd.Flags()); err != nil {
return err
}
v.log.Debugf("Using flags: %+v", v.flags)
fetcher := attestationconfigapi.NewFetcher() fetcher := attestationconfigapi.NewFetcher()
return v.verify(cmd, fileHandler, verifyClient, formatterFactory, fetcher) return v.verify(cmd, verifyClient, formatterFactory, fetcher)
} }
type formatterFactory func(output string, provider cloudprovider.Provider, log debugLog) (attestationDocFormatter, error) type formatterFactory func(output string, provider cloudprovider.Provider, log debugLog) (attestationDocFormatter, error)
func (c *verifyCmd) verify(cmd *cobra.Command, fileHandler file.Handler, verifyClient verifyClient, factory formatterFactory, configFetcher attestationconfigapi.Fetcher) error { func (c *verifyCmd) verify(cmd *cobra.Command, verifyClient verifyClient, factory formatterFactory, configFetcher attestationconfigapi.Fetcher) error {
flags, err := c.parseVerifyFlags(cmd, fileHandler) c.log.Debugf("Loading configuration file from %q", c.flags.pathPrefixer.PrefixPrintablePath(constants.ConfigFilename))
if err != nil { conf, err := config.New(c.fileHandler, constants.ConfigFilename, configFetcher, c.flags.force)
return fmt.Errorf("parsing flags: %w", err)
}
c.log.Debugf("Using flags: %+v", flags)
c.log.Debugf("Loading configuration file from %q", flags.pf.PrefixPrintablePath(constants.ConfigFilename))
conf, err := config.New(fileHandler, constants.ConfigFilename, configFetcher, flags.force)
var configValidationErr *config.ValidationError var configValidationErr *config.ValidationError
if errors.As(err, &configValidationErr) { if errors.As(err, &configValidationErr) {
cmd.PrintErrln(configValidationErr.LongMessage()) cmd.PrintErrln(configValidationErr.LongMessage())
@ -119,10 +151,29 @@ func (c *verifyCmd) verify(cmd *cobra.Command, fileHandler file.Handler, verifyC
return fmt.Errorf("loading config file: %w", err) return fmt.Errorf("loading config file: %w", err)
} }
conf.UpdateMAAURL(flags.maaURL) stateFile, err := state.ReadFromFile(c.fileHandler, constants.StateFilename)
if err != nil {
return fmt.Errorf("reading state file: %w", err)
}
ownerID, clusterID, err := c.validateIDFlags(cmd, stateFile)
if err != nil {
return err
}
endpoint, err := c.validateEndpointFlag(cmd, stateFile)
if err != nil {
return err
}
var maaURL string
if stateFile.Infrastructure.Azure != nil {
maaURL = stateFile.Infrastructure.Azure.AttestationURL
}
conf.UpdateMAAURL(maaURL)
c.log.Debugf("Updating expected PCRs") c.log.Debugf("Updating expected PCRs")
attConfig := conf.GetAttestationConfig() attConfig := conf.GetAttestationConfig()
if err := cloudcmd.UpdateInitMeasurements(attConfig, flags.ownerID, flags.clusterID); err != nil { if err := cloudcmd.UpdateInitMeasurements(attConfig, ownerID, clusterID); err != nil {
return fmt.Errorf("updating expected PCRs: %w", err) return fmt.Errorf("updating expected PCRs: %w", err)
} }
@ -140,7 +191,7 @@ func (c *verifyCmd) verify(cmd *cobra.Command, fileHandler file.Handler, verifyC
rawAttestationDoc, err := verifyClient.Verify( rawAttestationDoc, err := verifyClient.Verify(
cmd.Context(), cmd.Context(),
flags.endpoint, endpoint,
&verifyproto.GetAttestationRequest{ &verifyproto.GetAttestationRequest{
Nonce: nonce, Nonce: nonce,
}, },
@ -151,7 +202,7 @@ func (c *verifyCmd) verify(cmd *cobra.Command, fileHandler file.Handler, verifyC
} }
// certificates are only available for Azure // certificates are only available for Azure
formatter, err := factory(flags.output, conf.GetProvider(), c.log) formatter, err := factory(c.flags.output, conf.GetProvider(), c.log)
if err != nil { if err != nil {
return fmt.Errorf("creating formatter: %w", err) return fmt.Errorf("creating formatter: %w", err)
} }
@ -160,7 +211,7 @@ func (c *verifyCmd) verify(cmd *cobra.Command, fileHandler file.Handler, verifyC
rawAttestationDoc, rawAttestationDoc,
conf.Provider.Azure == nil, conf.Provider.Azure == nil,
attConfig.GetMeasurements(), attConfig.GetMeasurements(),
flags.maaURL, maaURL,
) )
if err != nil { if err != nil {
return fmt.Errorf("printing attestation document: %w", err) return fmt.Errorf("printing attestation document: %w", err)
@ -171,114 +222,37 @@ func (c *verifyCmd) verify(cmd *cobra.Command, fileHandler file.Handler, verifyC
return nil return nil
} }
func (c *verifyCmd) parseVerifyFlags(cmd *cobra.Command, fileHandler file.Handler) (verifyFlags, error) { func (c *verifyCmd) validateIDFlags(cmd *cobra.Command, stateFile *state.State) (ownerID, clusterID string, err error) {
workDir, err := cmd.Flags().GetString("workspace") ownerID, clusterID = c.flags.ownerID, c.flags.clusterID
if err != nil { if c.flags.clusterID == "" {
return verifyFlags{}, fmt.Errorf("parsing config path argument: %w", err) cmd.PrintErrf("Using ID from %q. Specify --cluster-id to override this.\n", c.flags.pathPrefixer.PrefixPrintablePath(constants.StateFilename))
}
c.log.Debugf("Flag 'workspace' set to %q", workDir)
pf := pathprefix.New(workDir)
ownerID := ""
clusterID, err := cmd.Flags().GetString("cluster-id")
if err != nil {
return verifyFlags{}, fmt.Errorf("parsing cluster-id argument: %w", err)
}
c.log.Debugf("Flag 'cluster-id' set to %q", clusterID)
endpoint, err := cmd.Flags().GetString("node-endpoint")
if err != nil {
return verifyFlags{}, fmt.Errorf("parsing node-endpoint argument: %w", err)
}
c.log.Debugf("Flag 'node-endpoint' set to %q", endpoint)
force, err := cmd.Flags().GetBool("force")
if err != nil {
return verifyFlags{}, fmt.Errorf("parsing force argument: %w", err)
}
c.log.Debugf("Flag 'force' set to %t", force)
output, err := cmd.Flags().GetString("output")
if err != nil {
return verifyFlags{}, fmt.Errorf("parsing raw argument: %w", err)
}
c.log.Debugf("Flag 'output' set to %t", output)
// Get empty values from state file
stateFile, err := state.ReadFromFile(fileHandler, constants.StateFilename)
isFileNotFound := errors.Is(err, afero.ErrFileNotFound)
if isFileNotFound {
c.log.Debugf("State file %q not found, using empty state", pf.PrefixPrintablePath(constants.StateFilename))
stateFile = state.New() // error compat
} else if err != nil {
return verifyFlags{}, fmt.Errorf("reading state file: %w", err)
}
emptyEndpoint := endpoint == ""
emptyIDs := ownerID == "" && clusterID == ""
if emptyEndpoint || emptyIDs {
c.log.Debugf("Trying to supplement empty flag values from %q", pf.PrefixPrintablePath(constants.StateFilename))
if emptyEndpoint {
cmd.PrintErrf("Using endpoint from %q. Specify --node-endpoint to override this.\n", pf.PrefixPrintablePath(constants.StateFilename))
endpoint = stateFile.Infrastructure.ClusterEndpoint
}
if emptyIDs {
cmd.PrintErrf("Using ID from %q. Specify --cluster-id to override this.\n", pf.PrefixPrintablePath(constants.StateFilename))
ownerID = stateFile.ClusterValues.OwnerID
clusterID = stateFile.ClusterValues.ClusterID clusterID = stateFile.ClusterValues.ClusterID
} }
} if ownerID == "" {
// We don't want to print warnings until this is implemented again
var attestationURL string // cmd.PrintErrf("Using ID from %q. Specify --owner-id to override this.\n", c.flags.pathPrefixer.PrefixPrintablePath(constants.StateFilename))
if stateFile.Infrastructure.Azure != nil { ownerID = stateFile.ClusterValues.OwnerID
attestationURL = stateFile.Infrastructure.Azure.AttestationURL
} }
// Validate // Validate
if ownerID == "" && clusterID == "" { if ownerID == "" && clusterID == "" {
return verifyFlags{}, errors.New("cluster-id not provided to verify the cluster") return "", "", errors.New("cluster-id not provided to verify the cluster")
}
endpoint, err = addPortIfMissing(endpoint, constants.VerifyServiceNodePortGRPC)
if err != nil {
return verifyFlags{}, fmt.Errorf("validating endpoint argument: %w", err)
} }
return verifyFlags{ return ownerID, clusterID, nil
endpoint: endpoint,
pf: pf,
ownerID: ownerID,
clusterID: clusterID,
output: output,
maaURL: attestationURL,
force: force,
}, nil
} }
type verifyFlags struct { func (c *verifyCmd) validateEndpointFlag(cmd *cobra.Command, stateFile *state.State) (string, error) {
endpoint string endpoint := c.flags.endpoint
ownerID string
clusterID string
maaURL string
output string
force bool
pf pathprefix.PathPrefixer
}
func addPortIfMissing(endpoint string, defaultPort int) (string, error) {
if endpoint == "" { if endpoint == "" {
return "", errors.New("endpoint is empty") cmd.PrintErrf("Using endpoint from %q. Specify --node-endpoint to override this.\n", c.flags.pathPrefixer.PrefixPrintablePath(constants.StateFilename))
endpoint = stateFile.Infrastructure.ClusterEndpoint
}
endpoint, err := addPortIfMissing(endpoint, constants.VerifyServiceNodePortGRPC)
if err != nil {
return "", fmt.Errorf("validating endpoint argument: %w", err)
} }
_, _, err := net.SplitHostPort(endpoint)
if err == nil {
return endpoint, nil return endpoint, nil
}
if strings.Contains(err.Error(), "missing port in address") {
return net.JoinHostPort(endpoint, strconv.Itoa(defaultPort)), nil
}
return "", err
} }
// an attestationDocFormatter formats the attestation document. // an attestationDocFormatter formats the attestation document.
@ -869,3 +843,20 @@ func extractAzureInstanceInfo(docString string) (azureInstanceInfo, error) {
} }
return instanceInfo, nil return instanceInfo, nil
} }
func addPortIfMissing(endpoint string, defaultPort int) (string, error) {
if endpoint == "" {
return "", errors.New("endpoint is empty")
}
_, _, err := net.SplitHostPort(endpoint)
if err == nil {
return endpoint, nil
}
if strings.Contains(err.Error(), "missing port in address") {
return net.JoinHostPort(endpoint, strconv.Itoa(defaultPort)), nil
}
return "", err
}

View file

@ -58,6 +58,7 @@ func TestVerify(t *testing.T) {
nodeEndpointFlag: "192.0.2.1:1234", nodeEndpointFlag: "192.0.2.1:1234",
clusterIDFlag: zeroBase64, clusterIDFlag: zeroBase64,
protoClient: &stubVerifyClient{}, protoClient: &stubVerifyClient{},
stateFile: state.New(),
wantEndpoint: "192.0.2.1:1234", wantEndpoint: "192.0.2.1:1234",
formatter: &stubAttDocFormatter{}, formatter: &stubAttDocFormatter{},
}, },
@ -66,6 +67,7 @@ func TestVerify(t *testing.T) {
nodeEndpointFlag: "192.0.2.1:1234", nodeEndpointFlag: "192.0.2.1:1234",
clusterIDFlag: zeroBase64, clusterIDFlag: zeroBase64,
protoClient: &stubVerifyClient{}, protoClient: &stubVerifyClient{},
stateFile: state.New(),
wantEndpoint: "192.0.2.1:1234", wantEndpoint: "192.0.2.1:1234",
formatter: &stubAttDocFormatter{}, formatter: &stubAttDocFormatter{},
}, },
@ -74,6 +76,7 @@ func TestVerify(t *testing.T) {
nodeEndpointFlag: "192.0.2.1", nodeEndpointFlag: "192.0.2.1",
clusterIDFlag: zeroBase64, clusterIDFlag: zeroBase64,
protoClient: &stubVerifyClient{}, protoClient: &stubVerifyClient{},
stateFile: state.New(),
wantEndpoint: "192.0.2.1:" + strconv.Itoa(constants.VerifyServiceNodePortGRPC), wantEndpoint: "192.0.2.1:" + strconv.Itoa(constants.VerifyServiceNodePortGRPC),
formatter: &stubAttDocFormatter{}, formatter: &stubAttDocFormatter{},
}, },
@ -81,6 +84,7 @@ func TestVerify(t *testing.T) {
provider: cloudprovider.GCP, provider: cloudprovider.GCP,
clusterIDFlag: zeroBase64, clusterIDFlag: zeroBase64,
protoClient: &stubVerifyClient{}, protoClient: &stubVerifyClient{},
stateFile: state.New(),
formatter: &stubAttDocFormatter{}, formatter: &stubAttDocFormatter{},
wantErr: true, wantErr: true,
}, },
@ -106,12 +110,14 @@ func TestVerify(t *testing.T) {
nodeEndpointFlag: ":::::", nodeEndpointFlag: ":::::",
clusterIDFlag: zeroBase64, clusterIDFlag: zeroBase64,
protoClient: &stubVerifyClient{}, protoClient: &stubVerifyClient{},
stateFile: state.New(),
formatter: &stubAttDocFormatter{}, formatter: &stubAttDocFormatter{},
wantErr: true, wantErr: true,
}, },
"neither owner id nor cluster id set": { "neither owner id nor cluster id set": {
provider: cloudprovider.GCP, provider: cloudprovider.GCP,
nodeEndpointFlag: "192.0.2.1:1234", nodeEndpointFlag: "192.0.2.1:1234",
stateFile: state.New(),
formatter: &stubAttDocFormatter{}, formatter: &stubAttDocFormatter{},
wantErr: true, wantErr: true,
}, },
@ -127,6 +133,7 @@ func TestVerify(t *testing.T) {
provider: cloudprovider.GCP, provider: cloudprovider.GCP,
clusterIDFlag: zeroBase64, clusterIDFlag: zeroBase64,
nodeEndpointFlag: "192.0.2.1:1234", nodeEndpointFlag: "192.0.2.1:1234",
stateFile: state.New(),
formatter: &stubAttDocFormatter{}, formatter: &stubAttDocFormatter{},
skipConfigCreation: true, skipConfigCreation: true,
wantErr: true, wantErr: true,
@ -136,6 +143,7 @@ func TestVerify(t *testing.T) {
nodeEndpointFlag: "192.0.2.1:1234", nodeEndpointFlag: "192.0.2.1:1234",
clusterIDFlag: zeroBase64, clusterIDFlag: zeroBase64,
protoClient: &stubVerifyClient{verifyErr: rpcStatus.Error(codes.Internal, "failed")}, protoClient: &stubVerifyClient{verifyErr: rpcStatus.Error(codes.Internal, "failed")},
stateFile: state.New(),
formatter: &stubAttDocFormatter{}, formatter: &stubAttDocFormatter{},
wantErr: true, wantErr: true,
}, },
@ -144,6 +152,7 @@ func TestVerify(t *testing.T) {
nodeEndpointFlag: "192.0.2.1:1234", nodeEndpointFlag: "192.0.2.1:1234",
clusterIDFlag: zeroBase64, clusterIDFlag: zeroBase64,
protoClient: &stubVerifyClient{verifyErr: someErr}, protoClient: &stubVerifyClient{verifyErr: someErr},
stateFile: state.New(),
formatter: &stubAttDocFormatter{}, formatter: &stubAttDocFormatter{},
wantErr: true, wantErr: true,
}, },
@ -152,6 +161,7 @@ func TestVerify(t *testing.T) {
nodeEndpointFlag: "192.0.2.1:1234", nodeEndpointFlag: "192.0.2.1:1234",
clusterIDFlag: zeroBase64, clusterIDFlag: zeroBase64,
protoClient: &stubVerifyClient{}, protoClient: &stubVerifyClient{},
stateFile: state.New(),
wantEndpoint: "192.0.2.1:1234", wantEndpoint: "192.0.2.1:1234",
formatter: &stubAttDocFormatter{formatErr: someErr}, formatter: &stubAttDocFormatter{formatErr: someErr},
wantErr: true, wantErr: true,
@ -164,31 +174,28 @@ func TestVerify(t *testing.T) {
require := require.New(t) require := require.New(t)
cmd := NewVerifyCmd() cmd := NewVerifyCmd()
cmd.Flags().String("workspace", "", "") // register persistent flag manually
cmd.Flags().Bool("force", true, "") // register persistent flag manually
out := &bytes.Buffer{} out := &bytes.Buffer{}
cmd.SetErr(out) cmd.SetErr(out)
if tc.clusterIDFlag != "" {
require.NoError(cmd.Flags().Set("cluster-id", tc.clusterIDFlag))
}
if tc.nodeEndpointFlag != "" {
require.NoError(cmd.Flags().Set("node-endpoint", tc.nodeEndpointFlag))
}
fileHandler := file.NewHandler(afero.NewMemMapFs()) fileHandler := file.NewHandler(afero.NewMemMapFs())
if !tc.skipConfigCreation { if !tc.skipConfigCreation {
cfg := defaultConfigWithExpectedMeasurements(t, config.Default(), tc.provider) cfg := defaultConfigWithExpectedMeasurements(t, config.Default(), tc.provider)
require.NoError(fileHandler.WriteYAML(constants.ConfigFilename, cfg)) require.NoError(fileHandler.WriteYAML(constants.ConfigFilename, cfg))
} }
if tc.stateFile != nil {
require.NoError(tc.stateFile.WriteToFile(fileHandler, constants.StateFilename)) require.NoError(tc.stateFile.WriteToFile(fileHandler, constants.StateFilename))
}
v := &verifyCmd{log: logger.NewTest(t)} v := &verifyCmd{
fileHandler: fileHandler,
log: logger.NewTest(t),
flags: verifyFlags{
clusterID: tc.clusterIDFlag,
endpoint: tc.nodeEndpointFlag,
},
}
formatterFac := func(_ string, _ cloudprovider.Provider, _ debugLog) (attestationDocFormatter, error) { formatterFac := func(_ string, _ cloudprovider.Provider, _ debugLog) (attestationDocFormatter, error) {
return tc.formatter, nil return tc.formatter, nil
} }
err := v.verify(cmd, fileHandler, tc.protoClient, formatterFac, stubAttestationFetcher{}) err := v.verify(cmd, tc.protoClient, formatterFac, stubAttestationFetcher{})
if tc.wantErr { if tc.wantErr {
assert.Error(err) assert.Error(err)
} else { } else {

5
go.mod
View file

@ -109,6 +109,7 @@ require (
github.com/sigstore/sigstore v1.7.1 github.com/sigstore/sigstore v1.7.1
github.com/spf13/afero v1.10.0 github.com/spf13/afero v1.10.0
github.com/spf13/cobra v1.7.0 github.com/spf13/cobra v1.7.0
github.com/spf13/pflag v1.0.5
github.com/stretchr/testify v1.8.4 github.com/stretchr/testify v1.8.4
github.com/theupdateframework/go-tuf v0.5.2 github.com/theupdateframework/go-tuf v0.5.2
github.com/tink-crypto/tink-go/v2 v2.0.0 github.com/tink-crypto/tink-go/v2 v2.0.0
@ -137,8 +138,6 @@ require (
sigs.k8s.io/yaml v1.3.0 sigs.k8s.io/yaml v1.3.0
) )
require github.com/google/go-tdx-guest v0.2.2 // indirect
require ( require (
cloud.google.com/go v0.110.2 // indirect cloud.google.com/go v0.110.2 // indirect
cloud.google.com/go/iam v1.1.0 // indirect cloud.google.com/go/iam v1.1.0 // indirect
@ -234,6 +233,7 @@ require (
github.com/google/go-attestation v0.5.0 // indirect github.com/google/go-attestation v0.5.0 // indirect
github.com/google/go-cmp v0.5.9 // indirect github.com/google/go-cmp v0.5.9 // indirect
github.com/google/go-containerregistry v0.15.2 // indirect github.com/google/go-containerregistry v0.15.2 // indirect
github.com/google/go-tdx-guest v0.2.2 // indirect
github.com/google/go-tspi v0.3.0 // indirect github.com/google/go-tspi v0.3.0 // indirect
github.com/google/gofuzz v1.2.0 // indirect github.com/google/gofuzz v1.2.0 // indirect
github.com/google/logger v1.1.1 // indirect github.com/google/logger v1.1.1 // indirect
@ -306,7 +306,6 @@ require (
github.com/shopspring/decimal v1.3.1 // indirect github.com/shopspring/decimal v1.3.1 // indirect
github.com/sirupsen/logrus v1.9.0 // indirect github.com/sirupsen/logrus v1.9.0 // indirect
github.com/spf13/cast v1.5.1 // indirect github.com/spf13/cast v1.5.1 // indirect
github.com/spf13/pflag v1.0.5 // indirect
github.com/stretchr/objx v0.5.0 // indirect github.com/stretchr/objx v0.5.0 // indirect
github.com/titanous/rocacheck v0.0.0-20171023193734-afe73141d399 // indirect github.com/titanous/rocacheck v0.0.0-20171023193734-afe73141d399 // indirect
github.com/transparency-dev/merkle v0.0.2 // indirect github.com/transparency-dev/merkle v0.0.2 // indirect