cli: refactor flag parsing code (#2425)

Signed-off-by: Daniel Weiße <dw@edgeless.systems>
This commit is contained in:
Daniel Weiße 2023-10-16 15:05:29 +02:00 committed by GitHub
parent adfe443b28
commit c52086c5ff
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
34 changed files with 1490 additions and 1726 deletions

View File

@ -13,7 +13,11 @@ go_library(
"configkubernetesversions.go",
"configmigrate.go",
"create.go",
"iam.go",
"iamcreate.go",
"iamcreateaws.go",
"iamcreateazure.go",
"iamcreategcp.go",
"iamdestroy.go",
"iamupgradeapply.go",
"init.go",
@ -87,6 +91,7 @@ go_library(
"@com_github_siderolabs_talos_pkg_machinery//config/encoder",
"@com_github_spf13_afero//:afero",
"@com_github_spf13_cobra//:cobra",
"@com_github_spf13_pflag//:pflag",
"@in_gopkg_yaml_v3//:yaml_v3",
"@io_k8s_apiextensions_apiserver//pkg/apis/apiextensions/v1:apiextensions",
"@io_k8s_apimachinery//pkg/runtime",

View File

@ -20,3 +20,58 @@ Common filepaths are defined as constants in the global "/internal/constants" pa
To generate workspace correct filepaths for printing, use the functions from the "workspace" package.
*/
package cmd
import (
"errors"
"fmt"
"github.com/edgelesssys/constellation/v2/cli/internal/cmd/pathprefix"
"github.com/edgelesssys/constellation/v2/cli/internal/terraform"
"github.com/spf13/pflag"
)
// rootFlags are flags defined on the root command.
// They are available to all subcommands.
type rootFlags struct {
pathPrefixer pathprefix.PathPrefixer
tfLogLevel terraform.LogLevel
debug bool
force bool
}
// parse flags into the rootFlags struct.
func (f *rootFlags) parse(flags *pflag.FlagSet) error {
var errs error
workspace, err := flags.GetString("workspace")
if err != nil {
errs = errors.Join(err, fmt.Errorf("getting 'workspace' flag: %w", err))
}
f.pathPrefixer = pathprefix.New(workspace)
tfLogString, err := flags.GetString("tf-log")
if err != nil {
errs = errors.Join(err, fmt.Errorf("getting 'tf-log' flag: %w", err))
}
f.tfLogLevel, err = terraform.ParseLogLevel(tfLogString)
if err != nil {
errs = errors.Join(err, fmt.Errorf("parsing 'tf-log' flag: %w", err))
}
f.debug, err = flags.GetBool("debug")
if err != nil {
errs = errors.Join(err, fmt.Errorf("getting 'debug' flag: %w", err))
}
f.force, err = flags.GetBool("force")
if err != nil {
errs = errors.Join(err, fmt.Errorf("getting 'force' flag: %w", err))
}
return errs
}
func must(err error) {
if err != nil {
panic(err)
}
}

View File

@ -14,7 +14,6 @@ import (
"net/url"
"time"
"github.com/edgelesssys/constellation/v2/cli/internal/cmd/pathprefix"
"github.com/edgelesssys/constellation/v2/cli/internal/featureset"
"github.com/edgelesssys/constellation/v2/internal/api/attestationconfigapi"
"github.com/edgelesssys/constellation/v2/internal/api/versionsapi"
@ -26,6 +25,7 @@ import (
"github.com/edgelesssys/constellation/v2/internal/sigstore/keyselect"
"github.com/spf13/afero"
"github.com/spf13/cobra"
"github.com/spf13/pflag"
)
func newConfigFetchMeasurementsCmd() *cobra.Command {
@ -46,14 +46,35 @@ func newConfigFetchMeasurementsCmd() *cobra.Command {
}
type fetchMeasurementsFlags struct {
rootFlags
measurementsURL *url.URL
signatureURL *url.URL
insecure bool
force bool
pf pathprefix.PathPrefixer
}
func (f *fetchMeasurementsFlags) parse(flags *pflag.FlagSet) error {
var err error
if err := f.rootFlags.parse(flags); err != nil {
return err
}
f.measurementsURL, err = parseURLFlag(flags, "url")
if err != nil {
return err
}
f.signatureURL, err = parseURLFlag(flags, "signature-url")
if err != nil {
return err
}
f.insecure, err = flags.GetBool("insecure")
if err != nil {
return fmt.Errorf("getting 'insecure' flag: %w", err)
}
return nil
}
type configFetchMeasurementsCmd struct {
flags fetchMeasurementsFlags
canFetchMeasurements bool
log debugLog
}
@ -70,6 +91,10 @@ func runConfigFetchMeasurements(cmd *cobra.Command, _ []string) error {
return fmt.Errorf("constructing Rekor client: %w", err)
}
cfm := &configFetchMeasurementsCmd{log: log, canFetchMeasurements: featureset.CanFetchMeasurements}
if err := cfm.flags.parse(cmd.Flags()); err != nil {
return fmt.Errorf("parsing flags: %w", err)
}
cfm.log.Debugf("Using flags %+v", cfm.flags)
fetcher := attestationconfigapi.NewFetcherWithClient(http.DefaultClient, constants.CDNRepositoryURL)
return cfm.configFetchMeasurements(cmd, sigstore.NewCosignVerifier, rekor, fileHandler, fetcher, http.DefaultClient)
@ -79,20 +104,14 @@ func (cfm *configFetchMeasurementsCmd) configFetchMeasurements(
cmd *cobra.Command, newCosignVerifier cosignVerifierConstructor, rekor rekorVerifier,
fileHandler file.Handler, fetcher attestationconfigapi.Fetcher, client *http.Client,
) error {
flags, err := cfm.parseFetchMeasurementsFlags(cmd)
if err != nil {
return err
}
cfm.log.Debugf("Using flags %v", flags)
if !cfm.canFetchMeasurements {
cmd.PrintErrln("Fetching measurements is not supported in the OSS build of the Constellation CLI. Consult the documentation for instructions on where to download the enterprise version.")
return errors.New("fetching measurements is not supported")
}
cfm.log.Debugf("Loading configuration file from %q", flags.pf.PrefixPrintablePath(constants.ConfigFilename))
cfm.log.Debugf("Loading configuration file from %q", cfm.flags.pathPrefixer.PrefixPrintablePath(constants.ConfigFilename))
conf, err := config.New(fileHandler, constants.ConfigFilename, fetcher, flags.force)
conf, err := config.New(fileHandler, constants.ConfigFilename, fetcher, cfm.flags.force)
var configValidationErr *config.ValidationError
if errors.As(err, &configValidationErr) {
cmd.PrintErrln(configValidationErr.LongMessage())
@ -110,7 +129,7 @@ func (cfm *configFetchMeasurementsCmd) configFetchMeasurements(
defer cancel()
cfm.log.Debugf("Updating URLs")
if err := flags.updateURLs(conf); err != nil {
if err := cfm.flags.updateURLs(conf); err != nil {
return err
}
@ -131,11 +150,11 @@ func (cfm *configFetchMeasurementsCmd) configFetchMeasurements(
var fetchedMeasurements measurements.M
var hash string
if flags.insecure {
if cfm.flags.insecure {
if err := fetchedMeasurements.FetchNoVerify(
ctx,
client,
flags.measurementsURL,
cfm.flags.measurementsURL,
imageVersion,
conf.GetProvider(),
conf.GetAttestationConfig().GetVariant(),
@ -149,8 +168,8 @@ func (cfm *configFetchMeasurementsCmd) configFetchMeasurements(
ctx,
client,
cosign,
flags.measurementsURL,
flags.signatureURL,
cfm.flags.measurementsURL,
cfm.flags.signatureURL,
imageVersion,
conf.GetProvider(),
conf.GetAttestationConfig().GetVariant(),
@ -173,63 +192,11 @@ func (cfm *configFetchMeasurementsCmd) configFetchMeasurements(
if err := fileHandler.WriteYAML(constants.ConfigFilename, conf, file.OptOverwrite); err != nil {
return err
}
cfm.log.Debugf("Configuration written to %s", flags.pf.PrefixPrintablePath(constants.ConfigFilename))
cfm.log.Debugf("Configuration written to %s", cfm.flags.pathPrefixer.PrefixPrintablePath(constants.ConfigFilename))
cmd.Print("Successfully fetched measurements and updated Configuration\n")
return nil
}
// parseURLFlag checks that flag can be parsed as URL.
// If no value was provided for flag, nil is returned.
func (cfm *configFetchMeasurementsCmd) parseURLFlag(cmd *cobra.Command, flag string) (*url.URL, error) {
rawURL, err := cmd.Flags().GetString(flag)
if err != nil {
return nil, fmt.Errorf("parsing config generate flags '%s': %w", flag, err)
}
cfm.log.Debugf("Flag %s has raw URL %q", flag, rawURL)
if rawURL != "" {
cfm.log.Debugf("Parsing raw URL")
return url.Parse(rawURL)
}
return nil, nil
}
func (cfm *configFetchMeasurementsCmd) parseFetchMeasurementsFlags(cmd *cobra.Command) (*fetchMeasurementsFlags, error) {
workDir, err := cmd.Flags().GetString("workspace")
if err != nil {
return nil, fmt.Errorf("parsing workspace argument: %w", err)
}
measurementsURL, err := cfm.parseURLFlag(cmd, "url")
if err != nil {
return nil, err
}
cfm.log.Debugf("Parsed measurements URL as %v", measurementsURL)
measurementsSignatureURL, err := cfm.parseURLFlag(cmd, "signature-url")
if err != nil {
return nil, err
}
cfm.log.Debugf("Parsed measurements signature URL as %v", measurementsSignatureURL)
insecure, err := cmd.Flags().GetBool("insecure")
if err != nil {
return nil, fmt.Errorf("parsing insecure argument: %w", err)
}
cfm.log.Debugf("Insecure flag is %v", insecure)
force, err := cmd.Flags().GetBool("force")
if err != nil {
return nil, fmt.Errorf("parsing force argument: %w", err)
}
return &fetchMeasurementsFlags{
measurementsURL: measurementsURL,
signatureURL: measurementsSignatureURL,
insecure: insecure,
force: force,
pf: pathprefix.New(workDir),
}, nil
}
func (f *fetchMeasurementsFlags) updateURLs(conf *config.Config) error {
ver, err := versionsapi.NewVersionFromShortPath(conf.Image, versionsapi.VersionKindImage)
if err != nil {
@ -250,6 +217,19 @@ func (f *fetchMeasurementsFlags) updateURLs(conf *config.Config) error {
return nil
}
// parseURLFlag checks that flag can be parsed as URL.
// If no value was provided for flag, nil is returned.
func parseURLFlag(flags *pflag.FlagSet, flag string) (*url.URL, error) {
rawURL, err := flags.GetString(flag)
if err != nil {
return nil, fmt.Errorf("getting '%s' flag: %w", flag, err)
}
if rawURL != "" {
return url.Parse(rawURL)
}
return nil, nil
}
type rekorVerifier interface {
SearchByHash(context.Context, string) ([]string, error)
VerifyEntry(context.Context, string, string) error

View File

@ -13,7 +13,6 @@ import (
"io"
"net/http"
"net/url"
"strconv"
"testing"
"github.com/edgelesssys/constellation/v2/internal/api/attestationconfigapi"
@ -39,11 +38,11 @@ func TestParseFetchMeasurementsFlags(t *testing.T) {
urlFlag string
signatureURLFlag string
forceFlag bool
wantFlags *fetchMeasurementsFlags
wantFlags fetchMeasurementsFlags
wantErr bool
}{
"default": {
wantFlags: &fetchMeasurementsFlags{
wantFlags: fetchMeasurementsFlags{
measurementsURL: nil,
signatureURL: nil,
},
@ -51,7 +50,7 @@ func TestParseFetchMeasurementsFlags(t *testing.T) {
"url": {
urlFlag: "https://some.other.url/with/path",
signatureURLFlag: "https://some.other.url/with/path.sig",
wantFlags: &fetchMeasurementsFlags{
wantFlags: fetchMeasurementsFlags{
measurementsURL: urlMustParse("https://some.other.url/with/path"),
signatureURL: urlMustParse("https://some.other.url/with/path.sig"),
},
@ -69,7 +68,9 @@ func TestParseFetchMeasurementsFlags(t *testing.T) {
cmd := newConfigFetchMeasurementsCmd()
cmd.Flags().String("workspace", "", "") // register persistent flag manually
cmd.Flags().Bool("force", false, "") // register persistent flag manually
cmd.Flags().Bool("force", false, "")
cmd.Flags().Bool("debug", false, "")
cmd.Flags().String("tf-log", "NONE", "")
if tc.urlFlag != "" {
require.NoError(cmd.Flags().Set("url", tc.urlFlag))
@ -77,8 +78,8 @@ func TestParseFetchMeasurementsFlags(t *testing.T) {
if tc.signatureURLFlag != "" {
require.NoError(cmd.Flags().Set("signature-url", tc.signatureURLFlag))
}
cfm := &configFetchMeasurementsCmd{log: logger.NewTest(t)}
flags, err := cfm.parseFetchMeasurementsFlags(cmd)
var flags fetchMeasurementsFlags
err := flags.parse(cmd.Flags())
if tc.wantErr {
assert.Error(err)
return
@ -270,9 +271,6 @@ func TestConfigFetchMeasurements(t *testing.T) {
require := require.New(t)
cmd := newConfigFetchMeasurementsCmd()
cmd.Flags().String("workspace", "", "") // register persistent flag manually
cmd.Flags().Bool("force", true, "") // register persistent flag manually
require.NoError(cmd.Flags().Set("insecure", strconv.FormatBool(tc.insecureFlag)))
fileHandler := file.NewHandler(afero.NewMemMapFs())
gcpConfig := defaultConfigWithExpectedMeasurements(t, config.Default(), cloudprovider.GCP)
@ -281,6 +279,8 @@ func TestConfigFetchMeasurements(t *testing.T) {
err := fileHandler.WriteYAML(constants.ConfigFilename, gcpConfig, file.OptMkdirAll)
require.NoError(err)
cfm := &configFetchMeasurementsCmd{canFetchMeasurements: true, log: logger.NewTest(t)}
cfm.flags.insecure = tc.insecureFlag
cfm.flags.force = true
err = cfm.configFetchMeasurements(cmd, tc.cosign, tc.rekor, fileHandler, stubAttestationFetcher{}, client)
if tc.wantErr {

View File

@ -10,7 +10,6 @@ import (
"fmt"
"strings"
"github.com/edgelesssys/constellation/v2/cli/internal/cmd/pathprefix"
"github.com/edgelesssys/constellation/v2/internal/attestation/variant"
"github.com/edgelesssys/constellation/v2/internal/cloud/cloudprovider"
"github.com/edgelesssys/constellation/v2/internal/config"
@ -19,6 +18,7 @@ import (
"github.com/edgelesssys/constellation/v2/internal/versions"
"github.com/spf13/afero"
"github.com/spf13/cobra"
"github.com/spf13/pflag"
"golang.org/x/mod/semver"
)
@ -41,13 +41,34 @@ func newConfigGenerateCmd() *cobra.Command {
}
type generateFlags struct {
pf pathprefix.PathPrefixer
rootFlags
k8sVersion versions.ValidK8sVersion
attestationVariant variant.Variant
}
func (f *generateFlags) parse(flags *pflag.FlagSet) error {
if err := f.rootFlags.parse(flags); err != nil {
return err
}
k8sVersion, err := parseK8sFlag(flags)
if err != nil {
return err
}
f.k8sVersion = k8sVersion
variant, err := parseAttestationFlag(flags)
if err != nil {
return err
}
f.attestationVariant = variant
return nil
}
type configGenerateCmd struct {
log debugLog
flags generateFlags
log debugLog
}
func runConfigGenerate(cmd *cobra.Command, args []string) error {
@ -56,31 +77,32 @@ func runConfigGenerate(cmd *cobra.Command, args []string) error {
return fmt.Errorf("creating logger: %w", err)
}
defer log.Sync()
fileHandler := file.NewHandler(afero.NewOsFs())
provider := cloudprovider.FromString(args[0])
cg := &configGenerateCmd{log: log}
if err := cg.flags.parse(cmd.Flags()); err != nil {
return fmt.Errorf("parsing flags: %w", err)
}
log.Debugf("Parsed flags as %+v", cg.flags)
return cg.configGenerate(cmd, fileHandler, provider, args[0])
}
func (cg *configGenerateCmd) configGenerate(cmd *cobra.Command, fileHandler file.Handler, provider cloudprovider.Provider, rawProvider string) error {
flags, err := parseGenerateFlags(cmd)
if err != nil {
return err
}
cg.log.Debugf("Parsed flags as %v", flags)
cg.log.Debugf("Using cloud provider %s", provider.String())
conf, err := createConfigWithAttestationVariant(provider, rawProvider, flags.attestationVariant)
conf, err := createConfigWithAttestationVariant(provider, rawProvider, cg.flags.attestationVariant)
if err != nil {
return fmt.Errorf("creating config: %w", err)
}
conf.KubernetesVersion = flags.k8sVersion
conf.KubernetesVersion = cg.flags.k8sVersion
cg.log.Debugf("Writing YAML data to configuration file")
if err := fileHandler.WriteYAML(constants.ConfigFilename, conf, file.OptMkdirAll); err != nil {
return err
}
cmd.Println("Config file written to", flags.pf.PrefixPrintablePath(constants.ConfigFilename))
cmd.Println("Config file written to", cg.flags.pathPrefixer.PrefixPrintablePath(constants.ConfigFilename))
cmd.Println("Please fill in your CSP-specific configuration before proceeding.")
cmd.Println("For more information refer to the documentation:")
cmd.Println("\thttps://docs.edgeless.systems/constellation/getting-started/first-steps")
@ -123,46 +145,6 @@ func createConfig(provider cloudprovider.Provider) *config.Config {
return res
}
func parseGenerateFlags(cmd *cobra.Command) (generateFlags, error) {
workDir, err := cmd.Flags().GetString("workspace")
if err != nil {
return generateFlags{}, fmt.Errorf("parsing workspace flag: %w", err)
}
k8sVersion, err := cmd.Flags().GetString("kubernetes")
if err != nil {
return generateFlags{}, fmt.Errorf("parsing Kubernetes flag: %w", err)
}
resolvedVersion, err := versions.ResolveK8sPatchVersion(k8sVersion)
if err != nil {
return generateFlags{}, fmt.Errorf("resolving kubernetes patch version from flag: %w", err)
}
validK8sVersion, err := versions.NewValidK8sVersion(resolvedVersion, true)
if err != nil {
return generateFlags{}, fmt.Errorf("resolving Kubernetes version from flag: %w", err)
}
attestationString, err := cmd.Flags().GetString("attestation")
if err != nil {
return generateFlags{}, fmt.Errorf("parsing attestation flag: %w", err)
}
var attestationVariant variant.Variant
// if no attestation variant is specified, use the default for the cloud provider
if attestationString == "" {
attestationVariant = variant.Dummy{}
} else {
attestationVariant, err = variant.FromString(attestationString)
if err != nil {
return generateFlags{}, fmt.Errorf("invalid attestation variant: %s", attestationString)
}
}
return generateFlags{
pf: pathprefix.New(workDir),
k8sVersion: validK8sVersion,
attestationVariant: attestationVariant,
}, nil
}
// generateCompletion handles the completion of the create command. It is frequently called
// while the user types arguments of the command to suggest completion.
func generateCompletion(_ *cobra.Command, args []string, _ string) ([]string, cobra.ShellCompDirective) {
@ -185,3 +167,39 @@ func toString[T any](t []T) []string {
}
return res
}
func parseK8sFlag(flags *pflag.FlagSet) (versions.ValidK8sVersion, error) {
versionString, err := flags.GetString("kubernetes")
if err != nil {
return "", fmt.Errorf("getting kubernetes flag: %w", err)
}
resolvedVersion, err := versions.ResolveK8sPatchVersion(versionString)
if err != nil {
return "", fmt.Errorf("resolving kubernetes patch version from flag: %w", err)
}
k8sVersion, err := versions.NewValidK8sVersion(resolvedVersion, true)
if err != nil {
return "", fmt.Errorf("resolving Kubernetes version from flag: %w", err)
}
return k8sVersion, nil
}
func parseAttestationFlag(flags *pflag.FlagSet) (variant.Variant, error) {
attestationString, err := flags.GetString("attestation")
if err != nil {
return nil, fmt.Errorf("getting attestation flag: %w", err)
}
var attestationVariant variant.Variant
// if no attestation variant is specified, use the default for the cloud provider
if attestationString == "" {
attestationVariant = variant.Dummy{}
} else {
attestationVariant, err = variant.FromString(attestationString)
if err != nil {
return nil, fmt.Errorf("invalid attestation variant: %s", attestationString)
}
}
return attestationVariant, nil
}

View File

@ -19,13 +19,12 @@ import (
"github.com/edgelesssys/constellation/v2/internal/logger"
"github.com/edgelesssys/constellation/v2/internal/versions"
"github.com/spf13/afero"
"github.com/spf13/cobra"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/mod/semver"
)
func TestConfigGenerateKubernetesVersion(t *testing.T) {
func TestParseKubernetesVersion(t *testing.T) {
testCases := map[string]struct {
version string
wantErr bool
@ -68,22 +67,18 @@ func TestConfigGenerateKubernetesVersion(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
fileHandler := file.NewHandler(afero.NewMemMapFs())
cmd := newConfigGenerateCmd()
cmd.Flags().String("workspace", "", "") // register persistent flag manually
flags := newConfigGenerateCmd().Flags()
if tc.version != "" {
err := cmd.Flags().Set("kubernetes", tc.version)
require.NoError(err)
require.NoError(flags.Set("kubernetes", tc.version))
}
cg := &configGenerateCmd{log: logger.NewTest(t)}
err := cg.configGenerate(cmd, fileHandler, cloudprovider.Unknown, "")
version, err := parseK8sFlag(flags)
if tc.wantErr {
assert.Error(err)
return
}
assert.NoError(err)
assert.Equal(versions.Default, version)
})
}
}
@ -94,9 +89,14 @@ func TestConfigGenerateDefault(t *testing.T) {
fileHandler := file.NewHandler(afero.NewMemMapFs())
cmd := newConfigGenerateCmd()
cmd.Flags().String("workspace", "", "") // register persistent flag manually
cg := &configGenerateCmd{log: logger.NewTest(t)}
cg := &configGenerateCmd{
log: logger.NewTest(t),
flags: generateFlags{
attestationVariant: variant.Dummy{},
k8sVersion: versions.Default,
},
}
require.NoError(cg.configGenerate(cmd, fileHandler, cloudprovider.Unknown, ""))
var readConfig config.Config
@ -106,53 +106,47 @@ func TestConfigGenerateDefault(t *testing.T) {
}
func TestConfigGenerateDefaultProviderSpecific(t *testing.T) {
providers := []cloudprovider.Provider{
cloudprovider.AWS,
cloudprovider.Azure,
cloudprovider.GCP,
cloudprovider.OpenStack,
testCases := map[string]struct {
provider cloudprovider.Provider
rawProvider string
}{
"aws": {
provider: cloudprovider.AWS,
},
"azure": {
provider: cloudprovider.Azure,
},
"gcp": {
provider: cloudprovider.GCP,
},
"openstack": {
provider: cloudprovider.OpenStack,
},
"stackit": {
provider: cloudprovider.OpenStack,
rawProvider: "stackit",
},
}
for _, provider := range providers {
t.Run(provider.String(), func(t *testing.T) {
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
fileHandler := file.NewHandler(afero.NewMemMapFs())
cmd := newConfigGenerateCmd()
cmd.Flags().String("workspace", "", "") // register persistent flag manually
wantConf := config.Default()
wantConf.RemoveProviderAndAttestationExcept(provider)
wantConf := config.Default().WithOpenStackProviderDefaults(tc.rawProvider)
wantConf.RemoveProviderAndAttestationExcept(tc.provider)
cg := &configGenerateCmd{log: logger.NewTest(t)}
require.NoError(cg.configGenerate(cmd, fileHandler, provider, ""))
var readConfig config.Config
err := fileHandler.ReadYAML(constants.ConfigFilename, &readConfig)
assert.NoError(err)
assert.Equal(*wantConf, readConfig)
})
}
}
func TestConfigGenerateWithStackIt(t *testing.T) {
openStackProviders := []string{"stackit"}
for _, openStackProvider := range openStackProviders {
t.Run(openStackProvider, func(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
fileHandler := file.NewHandler(afero.NewMemMapFs())
cmd := newConfigGenerateCmd()
cmd.Flags().String("workspace", "", "") // register persistent flag manually
wantConf := config.Default().WithOpenStackProviderDefaults(openStackProvider)
wantConf.RemoveProviderAndAttestationExcept(cloudprovider.OpenStack)
cg := &configGenerateCmd{log: logger.NewTest(t)}
require.NoError(cg.configGenerate(cmd, fileHandler, cloudprovider.OpenStack, openStackProvider))
cg := &configGenerateCmd{
log: logger.NewTest(t),
flags: generateFlags{
attestationVariant: variant.Dummy{},
k8sVersion: versions.Default,
},
}
require.NoError(cg.configGenerate(cmd, fileHandler, tc.provider, tc.rawProvider))
var readConfig config.Config
err := fileHandler.ReadYAML(constants.ConfigFilename, &readConfig)
@ -168,9 +162,11 @@ func TestConfigGenerateDefaultExists(t *testing.T) {
fileHandler := file.NewHandler(afero.NewMemMapFs())
require.NoError(fileHandler.Write(constants.ConfigFilename, []byte("foobar"), file.OptNone))
cmd := newConfigGenerateCmd()
cmd.Flags().String("workspace", "", "") // register persistent flag manually
cg := &configGenerateCmd{log: logger.NewTest(t)}
cg := &configGenerateCmd{
log: logger.NewTest(t),
flags: generateFlags{attestationVariant: variant.Dummy{}},
}
require.Error(cg.configGenerate(cmd, fileHandler, cloudprovider.Unknown, ""))
}
@ -247,64 +243,61 @@ func TestValidProviderAttestationCombination(t *testing.T) {
}
}
func TestAttestationArgument(t *testing.T) {
defaultAttestation := config.Default().Attestation
tests := []struct {
name string
provider cloudprovider.Provider
expectErr bool
expectedCfg config.AttestationConfig
setFlag func(*cobra.Command) error
func TestParseAttestationFlag(t *testing.T) {
testCases := map[string]struct {
wantErr bool
attestationFlag string
wantVariant variant.Variant
}{
{
name: "InvalidAttestationArgument",
provider: cloudprovider.Unknown,
expectErr: true,
setFlag: func(cmd *cobra.Command) error {
return cmd.Flags().Set("attestation", "unknown")
},
"invalid": {
wantErr: true,
attestationFlag: "unknown",
},
{
name: "ValidAttestationArgument",
provider: cloudprovider.Azure,
expectErr: false,
setFlag: func(cmd *cobra.Command) error {
return cmd.Flags().Set("attestation", "azure-trustedlaunch")
},
expectedCfg: config.AttestationConfig{AzureTrustedLaunch: defaultAttestation.AzureTrustedLaunch},
"AzureTrustedLaunch": {
attestationFlag: "azure-trustedlaunch",
wantVariant: variant.AzureTrustedLaunch{},
},
{
name: "WithoutAttestationArgument",
provider: cloudprovider.Azure,
expectErr: false,
setFlag: func(cmd *cobra.Command) error {
return nil
},
expectedCfg: config.AttestationConfig{AzureSEVSNP: defaultAttestation.AzureSEVSNP},
"AzureSEVSNP": {
attestationFlag: "azure-sev-snp",
wantVariant: variant.AzureSEVSNP{},
},
"AWSSEVSNP": {
attestationFlag: "aws-sev-snp",
wantVariant: variant.AWSSEVSNP{},
},
"AWSNitroTPM": {
attestationFlag: "aws-nitro-tpm",
wantVariant: variant.AWSNitroTPM{},
},
"GCPSEVES": {
attestationFlag: "gcp-sev-es",
wantVariant: variant.GCPSEVES{},
},
"QEMUVTPM": {
attestationFlag: "qemu-vtpm",
wantVariant: variant.QEMUVTPM{},
},
"no flag": {
wantVariant: variant.Dummy{},
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
require := assert.New(t)
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
require := require.New(t)
assert := assert.New(t)
cmd := newConfigGenerateCmd()
cmd.Flags().String("workspace", "", "") // register persistent flag manually
require.NoError(test.setFlag(cmd))
fileHandler := file.NewHandler(afero.NewMemMapFs())
cg := &configGenerateCmd{log: logger.NewTest(t)}
err := cg.configGenerate(cmd, fileHandler, test.provider, "")
if test.expectErr {
assert.Error(err)
} else {
assert.NoError(err)
var readConfig config.Config
require.NoError(fileHandler.ReadYAML(constants.ConfigFilename, &readConfig))
assert.Equal(test.expectedCfg, readConfig.Attestation)
if tc.attestationFlag != "" {
require.NoError(cmd.Flags().Set("attestation", tc.attestationFlag))
}
attestation, err := parseAttestationFlag(cmd.Flags())
if tc.wantErr {
assert.Error(err)
return
}
require.NoError(err)
assert.True(tc.wantVariant.Equal(attestation))
})
}
}

View File

@ -24,6 +24,7 @@ import (
"github.com/edgelesssys/constellation/v2/internal/semver"
"github.com/spf13/afero"
"github.com/spf13/cobra"
"github.com/spf13/pflag"
)
// NewCreateCmd returns a new cobra.Command for the create command.
@ -39,9 +40,29 @@ func NewCreateCmd() *cobra.Command {
return cmd
}
// createFlags contains the parsed flags of the create command.
type createFlags struct {
rootFlags
yes bool
}
// parse parses the flags of the create command.
func (f *createFlags) parse(flags *pflag.FlagSet) error {
if err := f.rootFlags.parse(flags); err != nil {
return err
}
yes, err := flags.GetBool("yes")
if err != nil {
return fmt.Errorf("getting 'yes' flag: %w", err)
}
f.yes = yes
return nil
}
type createCmd struct {
log debugLog
pf pathprefix.PathPrefixer
log debugLog
flags createFlags
}
func runCreate(cmd *cobra.Command, _ []string) error {
@ -59,22 +80,22 @@ func runCreate(cmd *cobra.Command, _ []string) error {
fileHandler := file.NewHandler(afero.NewOsFs())
creator := cloudcmd.NewCreator(spinner)
c := &createCmd{log: log}
if err := c.flags.parse(cmd.Flags()); err != nil {
return err
}
c.log.Debugf("Using flags: %+v", c.flags)
fetcher := attestationconfigapi.NewFetcher()
return c.create(cmd, creator, fileHandler, spinner, fetcher)
}
func (c *createCmd) create(cmd *cobra.Command, creator cloudCreator, fileHandler file.Handler, spinner spinnerInterf, fetcher attestationconfigapi.Fetcher) (retErr error) {
flags, err := c.parseCreateFlags(cmd)
if err != nil {
return err
}
c.log.Debugf("Using flags: %+v", flags)
if err := c.checkDirClean(fileHandler); err != nil {
return err
}
c.log.Debugf("Loading configuration file from %q", c.pf.PrefixPrintablePath(constants.ConfigFilename))
conf, err := config.New(fileHandler, constants.ConfigFilename, fetcher, flags.force)
c.log.Debugf("Loading configuration file from %q", c.flags.pathPrefixer.PrefixPrintablePath(constants.ConfigFilename))
conf, err := config.New(fileHandler, constants.ConfigFilename, fetcher, c.flags.force)
c.log.Debugf("Configuration file loaded: %+v", conf)
var configValidationErr *config.ValidationError
if errors.As(err, &configValidationErr) {
@ -83,7 +104,7 @@ func (c *createCmd) create(cmd *cobra.Command, creator cloudCreator, fileHandler
if err != nil {
return err
}
if !flags.force {
if !c.flags.force {
if err := validateCLIandConstellationVersionAreEqual(constants.BinaryVersion(), conf.Image, conf.MicroserviceVersion); err != nil {
return err
}
@ -137,7 +158,7 @@ func (c *createCmd) create(cmd *cobra.Command, creator cloudCreator, fileHandler
c.log.Debugf("Creating %d additional node groups: %v", len(otherGroupNames), otherGroupNames)
}
if !flags.yes {
if !c.flags.yes {
// Ask user to confirm action.
cmd.Printf("The following Constellation cluster will be created:\n")
cmd.Printf(" %d control-plane node%s of type %s will be created.\n", controlPlaneGroup.InitialCount, isPlural(controlPlaneGroup.InitialCount), controlPlaneGroup.InstanceType)
@ -160,13 +181,13 @@ func (c *createCmd) create(cmd *cobra.Command, creator cloudCreator, fileHandler
opts := cloudcmd.CreateOptions{
Provider: provider,
Config: conf,
TFLogLevel: flags.tfLogLevel,
TFLogLevel: c.flags.tfLogLevel,
TFWorkspace: constants.TerraformWorkingDir,
}
infraState, err := creator.Create(cmd.Context(), opts)
spinner.Stop()
if err != nil {
return translateCreateErrors(cmd, c.pf, err)
return translateCreateErrors(cmd, c.flags.pathPrefixer, err)
}
c.log.Debugf("Successfully created the cloud resources for the cluster")
@ -179,64 +200,28 @@ func (c *createCmd) create(cmd *cobra.Command, creator cloudCreator, fileHandler
return nil
}
// parseCreateFlags parses the flags of the create command.
func (c *createCmd) parseCreateFlags(cmd *cobra.Command) (createFlags, error) {
yes, err := cmd.Flags().GetBool("yes")
if err != nil {
return createFlags{}, fmt.Errorf("parsing yes bool: %w", err)
}
c.log.Debugf("Yes flag is %t", yes)
workDir, err := cmd.Flags().GetString("workspace")
if err != nil {
return createFlags{}, fmt.Errorf("parsing config path argument: %w", err)
}
c.log.Debugf("Workspace set to %q", workDir)
c.pf = pathprefix.New(workDir)
force, err := cmd.Flags().GetBool("force")
if err != nil {
return createFlags{}, fmt.Errorf("parsing force argument: %w", err)
}
c.log.Debugf("force flag is %t", force)
logLevelString, err := cmd.Flags().GetString("tf-log")
if err != nil {
return createFlags{}, fmt.Errorf("parsing tf-log string: %w", err)
}
logLevel, err := terraform.ParseLogLevel(logLevelString)
if err != nil {
return createFlags{}, fmt.Errorf("parsing Terraform log level %s: %w", logLevelString, err)
}
c.log.Debugf("Terraform logs will be written into %s at level %s", c.pf.PrefixPrintablePath(constants.TerraformLogFile), logLevel.String())
return createFlags{
tfLogLevel: logLevel,
force: force,
yes: yes,
}, nil
}
// createFlags contains the parsed flags of the create command.
type createFlags struct {
tfLogLevel terraform.LogLevel
force bool
yes bool
}
// checkDirClean checks if files of a previous Constellation are left in the current working dir.
func (c *createCmd) checkDirClean(fileHandler file.Handler) error {
c.log.Debugf("Checking admin configuration file")
if _, err := fileHandler.Stat(constants.AdminConfFilename); !errors.Is(err, fs.ErrNotExist) {
return fmt.Errorf("file '%s' already exists in working directory, run 'constellation terminate' before creating a new one", c.pf.PrefixPrintablePath(constants.AdminConfFilename))
return fmt.Errorf(
"file '%s' already exists in working directory, run 'constellation terminate' before creating a new one",
c.flags.pathPrefixer.PrefixPrintablePath(constants.AdminConfFilename),
)
}
c.log.Debugf("Checking master secrets file")
if _, err := fileHandler.Stat(constants.MasterSecretFilename); !errors.Is(err, fs.ErrNotExist) {
return fmt.Errorf("file '%s' already exists in working directory. Constellation won't overwrite previous master secrets. Move it somewhere or delete it before creating a new cluster", c.pf.PrefixPrintablePath(constants.MasterSecretFilename))
return fmt.Errorf(
"file '%s' already exists in working directory. Constellation won't overwrite previous master secrets. Move it somewhere or delete it before creating a new cluster",
c.flags.pathPrefixer.PrefixPrintablePath(constants.MasterSecretFilename),
)
}
c.log.Debugf("Checking state file")
if _, err := fileHandler.Stat(constants.StateFilename); !errors.Is(err, fs.ErrNotExist) {
return fmt.Errorf("file '%s' already exists in working directory. Constellation won't overwrite previous cluster state. Move it somewhere or delete it before creating a new cluster", c.pf.PrefixPrintablePath(constants.StateFilename))
return fmt.Errorf(
"file '%s' already exists in working directory. Constellation won't overwrite previous cluster state. Move it somewhere or delete it before creating a new cluster",
c.flags.pathPrefixer.PrefixPrintablePath(constants.StateFilename),
)
}
return nil
@ -270,12 +255,6 @@ func isPlural(count int) string {
return "s"
}
func must(err error) {
if err != nil {
panic(err)
}
}
// validateCLIandConstellationVersionAreEqual checks if the image and microservice version are equal (down to patch level) to the CLI version.
func validateCLIandConstellationVersionAreEqual(cliVersion semver.Semver, imageVersion string, microserviceVersion semver.Semver) error {
parsedImageVersion, err := versionsapi.NewVersionFromShortPath(imageVersion, versionsapi.VersionKindImage)

View File

@ -133,16 +133,9 @@ func TestCreate(t *testing.T) {
cmd.SetOut(&bytes.Buffer{})
cmd.SetErr(&bytes.Buffer{})
cmd.SetIn(bytes.NewBufferString(tc.stdin))
cmd.Flags().String("workspace", "", "") // register persistent flag manually
cmd.Flags().Bool("force", true, "") // register persistent flag manually
cmd.Flags().String("tf-log", "NONE", "") // register persistent flag manually
if tc.yesFlag {
require.NoError(cmd.Flags().Set("yes", "true"))
}
fileHandler := file.NewHandler(tc.setupFs(require, tc.provider))
c := &createCmd{log: logger.NewTest(t)}
c := &createCmd{log: logger.NewTest(t), flags: createFlags{yes: tc.yesFlag}}
err := c.create(cmd, tc.creator, fileHandler, &nopSpinner{}, stubAttestationFetcher{})
if tc.wantErr {

23
cli/internal/cmd/iam.go Normal file
View File

@ -0,0 +1,23 @@
/*
Copyright (c) Edgeless Systems GmbH
SPDX-License-Identifier: AGPL-3.0-only
*/
package cmd
import "github.com/spf13/cobra"
// NewIAMCmd returns a new cobra.Command for the iam parent command. It needs another verb and does nothing on its own.
func NewIAMCmd() *cobra.Command {
cmd := &cobra.Command{
Use: "iam",
Short: "Work with the IAM configuration on your cloud provider",
Long: "Work with the IAM configuration on your cloud provider.",
Args: cobra.ExactArgs(0),
}
cmd.AddCommand(newIAMCreateCmd())
cmd.AddCommand(newIAMDestroyCmd())
cmd.AddCommand(newIAMUpgradeCmd())
return cmd
}

View File

@ -11,17 +11,15 @@ import (
"encoding/json"
"fmt"
"regexp"
"strings"
"github.com/edgelesssys/constellation/v2/cli/internal/cloudcmd"
"github.com/edgelesssys/constellation/v2/cli/internal/cmd/pathprefix"
"github.com/edgelesssys/constellation/v2/cli/internal/terraform"
"github.com/edgelesssys/constellation/v2/internal/cloud/cloudprovider"
"github.com/edgelesssys/constellation/v2/internal/config"
"github.com/edgelesssys/constellation/v2/internal/constants"
"github.com/edgelesssys/constellation/v2/internal/file"
"github.com/spf13/afero"
"github.com/spf13/cobra"
"github.com/spf13/pflag"
)
var (
@ -33,22 +31,7 @@ var (
gcpIDRegex = regexp.MustCompile(`^[a-z][-a-z0-9]{4,28}[a-z0-9]$`)
)
// NewIAMCmd returns a new cobra.Command for the iam parent command. It needs another verb and does nothing on its own.
func NewIAMCmd() *cobra.Command {
cmd := &cobra.Command{
Use: "iam",
Short: "Work with the IAM configuration on your cloud provider",
Long: "Work with the IAM configuration on your cloud provider.",
Args: cobra.ExactArgs(0),
}
cmd.AddCommand(newIAMCreateCmd())
cmd.AddCommand(newIAMDestroyCmd())
cmd.AddCommand(newIAMUpgradeCmd())
return cmd
}
// NewIAMCreateCmd returns a new cobra.Command for the iam create parent command. It needs another verb, and does nothing on its own.
// newIAMCreateCmd returns a new cobra.Command for the iam create parent command. It needs another verb, and does nothing on its own.
func newIAMCreateCmd() *cobra.Command {
cmd := &cobra.Command{
Use: "create",
@ -67,135 +50,54 @@ func newIAMCreateCmd() *cobra.Command {
return cmd
}
// newIAMCreateAWSCmd returns a new cobra.Command for the iam create aws command.
func newIAMCreateAWSCmd() *cobra.Command {
cmd := &cobra.Command{
Use: "aws",
Short: "Create IAM configuration on AWS for your Constellation cluster",
Long: "Create IAM configuration on AWS for your Constellation cluster.",
Args: cobra.ExactArgs(0),
RunE: createRunIAMFunc(cloudprovider.AWS),
}
cmd.Flags().String("prefix", "", "name prefix for all resources (required)")
must(cobra.MarkFlagRequired(cmd.Flags(), "prefix"))
cmd.Flags().String("zone", "", "AWS availability zone the resources will be created in, e.g., us-east-2a (required)\n"+
"See the Constellation docs for a list of currently supported regions.")
must(cobra.MarkFlagRequired(cmd.Flags(), "zone"))
return cmd
type iamCreateFlags struct {
rootFlags
yes bool
updateConfig bool
}
// newIAMCreateAzureCmd returns a new cobra.Command for the iam create azure command.
func newIAMCreateAzureCmd() *cobra.Command {
cmd := &cobra.Command{
Use: "azure",
Short: "Create IAM configuration on Microsoft Azure for your Constellation cluster",
Long: "Create IAM configuration on Microsoft Azure for your Constellation cluster.",
Args: cobra.ExactArgs(0),
RunE: createRunIAMFunc(cloudprovider.Azure),
func (f *iamCreateFlags) parse(flags *pflag.FlagSet) error {
var err error
if err = f.rootFlags.parse(flags); err != nil {
return err
}
cmd.Flags().String("resourceGroup", "", "name prefix of the two resource groups your cluster / IAM resources will be created in (required)")
must(cobra.MarkFlagRequired(cmd.Flags(), "resourceGroup"))
cmd.Flags().String("region", "", "region the resources will be created in, e.g., westus (required)")
must(cobra.MarkFlagRequired(cmd.Flags(), "region"))
cmd.Flags().String("servicePrincipal", "", "name of the service principal that will be created (required)")
must(cobra.MarkFlagRequired(cmd.Flags(), "servicePrincipal"))
return cmd
f.yes, err = flags.GetBool("yes")
if err != nil {
return fmt.Errorf("getting 'yes' flag: %w", err)
}
f.updateConfig, err = flags.GetBool("update-config")
if err != nil {
return fmt.Errorf("getting 'update-config' flag: %w", err)
}
return nil
}
// NewIAMCreateGCPCmd returns a new cobra.Command for the iam create gcp command.
func newIAMCreateGCPCmd() *cobra.Command {
cmd := &cobra.Command{
Use: "gcp",
Short: "Create IAM configuration on GCP for your Constellation cluster",
Long: "Create IAM configuration on GCP for your Constellation cluster.",
Args: cobra.ExactArgs(0),
RunE: createRunIAMFunc(cloudprovider.GCP),
}
cmd.Flags().String("zone", "", "GCP zone the cluster will be deployed in (required)\n"+
"Find a list of available zones here: https://cloud.google.com/compute/docs/regions-zones#available")
must(cobra.MarkFlagRequired(cmd.Flags(), "zone"))
cmd.Flags().String("serviceAccountID", "", "ID for the service account that will be created (required)\n"+
"Must be 6 to 30 lowercase letters, digits, or hyphens.")
must(cobra.MarkFlagRequired(cmd.Flags(), "serviceAccountID"))
cmd.Flags().String("projectID", "", "ID of the GCP project the configuration will be created in (required)\n"+
"Find it on the welcome screen of your project: https://console.cloud.google.com/welcome")
must(cobra.MarkFlagRequired(cmd.Flags(), "projectID"))
return cmd
}
// createRunIAMFunc is the entrypoint for the iam create command. It sets up the iamCreator
// and starts IAM creation for the specific cloud provider.
func createRunIAMFunc(provider cloudprovider.Provider) func(cmd *cobra.Command, args []string) error {
var providerCreator func(pf pathprefix.PathPrefixer) providerIAMCreator
switch provider {
case cloudprovider.AWS:
providerCreator = func(pathprefix.PathPrefixer) providerIAMCreator { return &awsIAMCreator{} }
case cloudprovider.Azure:
providerCreator = func(pathprefix.PathPrefixer) providerIAMCreator { return &azureIAMCreator{} }
case cloudprovider.GCP:
providerCreator = func(pf pathprefix.PathPrefixer) providerIAMCreator {
return &gcpIAMCreator{pf}
}
default:
return func(cmd *cobra.Command, args []string) error {
return fmt.Errorf("unknown provider %s", provider)
}
}
return func(cmd *cobra.Command, args []string) error {
logLevelString, err := cmd.Flags().GetString("tf-log")
if err != nil {
return fmt.Errorf("parsing tf-log string: %w", err)
}
logLevel, err := terraform.ParseLogLevel(logLevelString)
if err != nil {
return fmt.Errorf("parsing Terraform log level %s: %w", logLevelString, err)
}
workDir, err := cmd.Flags().GetString("workspace")
if err != nil {
return fmt.Errorf("parsing workspace string: %w", err)
}
pf := pathprefix.New(workDir)
iamCreator, err := newIAMCreator(cmd, pf, logLevel)
if err != nil {
return fmt.Errorf("creating iamCreator: %w", err)
}
defer iamCreator.spinner.Stop()
defer iamCreator.log.Sync()
iamCreator.provider = provider
iamCreator.providerCreator = providerCreator(pf)
return iamCreator.create(cmd.Context())
}
}
// newIAMCreator creates a new iamiamCreator.
func newIAMCreator(cmd *cobra.Command, pf pathprefix.PathPrefixer, logLevel terraform.LogLevel) (*iamCreator, error) {
func runIAMCreate(cmd *cobra.Command, providerCreator providerIAMCreator, provider cloudprovider.Provider) error {
spinner, err := newSpinnerOrStderr(cmd)
if err != nil {
return nil, fmt.Errorf("creating spinner: %w", err)
return fmt.Errorf("creating spinner: %w", err)
}
defer spinner.Stop()
log, err := newCLILogger(cmd)
if err != nil {
return nil, fmt.Errorf("creating logger: %w", err)
return fmt.Errorf("creating logger: %w", err)
}
log.Debugf("Terraform logs will be written into %s at level %s", pf.PrefixPrintablePath(constants.TerraformLogFile), logLevel.String())
defer log.Sync()
return &iamCreator{
cmd: cmd,
spinner: spinner,
log: log,
creator: cloudcmd.NewIAMCreator(spinner),
fileHandler: file.NewHandler(afero.NewOsFs()),
iamConfig: &cloudcmd.IAMConfigOptions{
TFWorkspace: constants.TerraformIAMWorkingDir,
TFLogLevel: logLevel,
},
}, nil
iamCreator := &iamCreator{
cmd: cmd,
spinner: spinner,
log: log,
creator: cloudcmd.NewIAMCreator(spinner),
fileHandler: file.NewHandler(afero.NewOsFs()),
providerCreator: providerCreator,
provider: provider,
}
if err := iamCreator.flags.parse(cmd.Flags()); err != nil {
return err
}
return iamCreator.create(cmd.Context())
}
// iamCreator is the iamCreator for the iam create command.
@ -208,24 +110,18 @@ type iamCreator struct {
providerCreator providerIAMCreator
iamConfig *cloudcmd.IAMConfigOptions
log debugLog
pf pathprefix.PathPrefixer
flags iamCreateFlags
}
// create IAM configuration on the iamCreator's cloud provider.
func (c *iamCreator) create(ctx context.Context) error {
flags, err := c.parseFlagsAndSetupConfig()
if err != nil {
return err
}
c.log.Debugf("Using flags: %+v", flags)
if err := c.checkWorkingDir(); err != nil {
return err
}
if !flags.yesFlag {
if !c.flags.yes {
c.cmd.Printf("The following IAM configuration will be created:\n\n")
c.providerCreator.printConfirmValues(c.cmd, flags)
c.providerCreator.printConfirmValues(c.cmd)
ok, err := askToConfirm(c.cmd, "Do you want to create the configuration?")
if err != nil {
return err
@ -237,19 +133,22 @@ func (c *iamCreator) create(ctx context.Context) error {
}
var conf config.Config
if flags.updateConfig {
c.log.Debugf("Parsing config %s", c.pf.PrefixPrintablePath(constants.ConfigFilename))
if err = c.fileHandler.ReadYAML(constants.ConfigFilename, &conf); err != nil {
if c.flags.updateConfig {
c.log.Debugf("Parsing config %s", c.flags.pathPrefixer.PrefixPrintablePath(constants.ConfigFilename))
if err := c.fileHandler.ReadYAML(constants.ConfigFilename, &conf); err != nil {
return fmt.Errorf("error reading the configuration file: %w", err)
}
if err := validateConfigWithFlagCompatibility(c.provider, conf, flags); err != nil {
if err := c.providerCreator.validateConfigWithFlagCompatibility(conf); err != nil {
return err
}
c.cmd.Printf("The configuration file %q will be automatically updated with the IAM values and zone/region information.\n", c.pf.PrefixPrintablePath(constants.ConfigFilename))
c.cmd.Printf("The configuration file %q will be automatically updated with the IAM values and zone/region information.\n", c.flags.pathPrefixer.PrefixPrintablePath(constants.ConfigFilename))
}
iamConfig := c.providerCreator.getIAMConfigOptions()
iamConfig.TFWorkspace = constants.TerraformIAMWorkingDir
iamConfig.TFLogLevel = c.flags.tfLogLevel
c.spinner.Start("Creating", false)
iamFile, err := c.creator.Create(ctx, c.provider, c.iamConfig)
iamFile, err := c.creator.Create(ctx, c.provider, iamConfig)
c.spinner.Stop()
if err != nil {
return err
@ -262,321 +161,47 @@ func (c *iamCreator) create(ctx context.Context) error {
return err
}
if flags.updateConfig {
c.log.Debugf("Writing IAM configuration to %s", c.pf.PrefixPrintablePath(constants.ConfigFilename))
c.providerCreator.writeOutputValuesToConfig(&conf, flags, iamFile)
if c.flags.updateConfig {
c.log.Debugf("Writing IAM configuration to %s", c.flags.pathPrefixer.PrefixPrintablePath(constants.ConfigFilename))
c.providerCreator.writeOutputValuesToConfig(&conf, iamFile)
if err := c.fileHandler.WriteYAML(constants.ConfigFilename, conf, file.OptOverwrite); err != nil {
return err
}
c.cmd.Printf("Your IAM configuration was created and filled into %s successfully.\n", c.pf.PrefixPrintablePath(constants.ConfigFilename))
c.cmd.Printf("Your IAM configuration was created and filled into %s successfully.\n", c.flags.pathPrefixer.PrefixPrintablePath(constants.ConfigFilename))
return nil
}
c.providerCreator.printOutputValues(c.cmd, flags, iamFile)
c.providerCreator.printOutputValues(c.cmd, iamFile)
c.cmd.Println("Your IAM configuration was created successfully. Please fill the above values into your configuration file.")
return nil
}
// parseFlagsAndSetupConfig parses the flags of the iam create command and fills the values into the IAM config (output values of the command).
func (c *iamCreator) parseFlagsAndSetupConfig() (iamFlags, error) {
workDir, err := c.cmd.Flags().GetString("workspace")
if err != nil {
return iamFlags{}, fmt.Errorf("parsing config string: %w", err)
}
c.pf = pathprefix.New(workDir)
yesFlag, err := c.cmd.Flags().GetBool("yes")
if err != nil {
return iamFlags{}, fmt.Errorf("parsing yes bool: %w", err)
}
updateConfig, err := c.cmd.Flags().GetBool("update-config")
if err != nil {
return iamFlags{}, fmt.Errorf("parsing update-config bool: %w", err)
}
flags := iamFlags{
yesFlag: yesFlag,
updateConfig: updateConfig,
}
flags, err = c.providerCreator.parseFlagsAndSetupConfig(c.cmd, flags, c.iamConfig)
if err != nil {
return iamFlags{}, fmt.Errorf("parsing provider-specific value: %w", err)
}
return flags, nil
}
// checkWorkingDir checks if the current working directory already contains a Terraform dir.
func (c *iamCreator) checkWorkingDir() error {
if _, err := c.fileHandler.Stat(constants.TerraformIAMWorkingDir); err == nil {
return fmt.Errorf("the current working directory already contains the Terraform workspace directory %q. Please run the command in a different directory or destroy the existing workspace", c.pf.PrefixPrintablePath(constants.TerraformIAMWorkingDir))
return fmt.Errorf(
"the current working directory already contains the Terraform workspace directory %q. Please run the command in a different directory or destroy the existing workspace",
c.flags.pathPrefixer.PrefixPrintablePath(constants.TerraformIAMWorkingDir),
)
}
return nil
}
// iamFlags contains the parsed flags of the iam create command, including the parsed flags of the selected cloud provider.
type iamFlags struct {
aws awsFlags
azure azureFlags
gcp gcpFlags
yesFlag bool
updateConfig bool
}
// awsFlags contains the parsed flags of the iam create aws command.
type awsFlags struct {
prefix string
region string
zone string
}
// azureFlags contains the parsed flags of the iam create azure command.
type azureFlags struct {
region string
resourceGroup string
servicePrincipal string
}
// gcpFlags contains the parsed flags of the iam create gcp command.
type gcpFlags struct {
serviceAccountID string
zone string
region string
projectID string
}
// providerIAMCreator is an interface for the IAM actions of different cloud providers.
type providerIAMCreator interface {
// printConfirmValues prints the values that will be created on the cloud provider and need to be confirmed by the user.
printConfirmValues(cmd *cobra.Command, flags iamFlags)
printConfirmValues(cmd *cobra.Command)
// printOutputValues prints the values that were created on the cloud provider.
printOutputValues(cmd *cobra.Command, flags iamFlags, iamFile cloudcmd.IAMOutput)
printOutputValues(cmd *cobra.Command, iamFile cloudcmd.IAMOutput)
// writeOutputValuesToConfig writes the output values of the IAM creation to the constellation config file.
writeOutputValuesToConfig(conf *config.Config, flags iamFlags, iamFile cloudcmd.IAMOutput)
// parseFlagsAndSetupConfig parses the provider-specific flags and fills the values into the IAM config (output values of the command).
parseFlagsAndSetupConfig(cmd *cobra.Command, flags iamFlags, iamConfig *cloudcmd.IAMConfigOptions) (iamFlags, error)
writeOutputValuesToConfig(conf *config.Config, iamFile cloudcmd.IAMOutput)
// getIAMConfigOptions sets up the IAM values required to create the IAM configuration.
getIAMConfigOptions() *cloudcmd.IAMConfigOptions
// parseAndWriteIDFile parses the GCP service account key and writes it to a keyfile. It is only implemented for GCP.
parseAndWriteIDFile(iamFile cloudcmd.IAMOutput, fileHandler file.Handler) error
}
// awsIAMCreator implements the providerIAMCreator interface for AWS.
type awsIAMCreator struct{}
func (c *awsIAMCreator) parseFlagsAndSetupConfig(cmd *cobra.Command, flags iamFlags, iamConfig *cloudcmd.IAMConfigOptions) (iamFlags, error) {
prefix, err := cmd.Flags().GetString("prefix")
if err != nil {
return iamFlags{}, fmt.Errorf("parsing prefix string: %w", err)
}
if len(prefix) > 36 {
return iamFlags{}, fmt.Errorf("prefix must be 36 characters or less")
}
zone, err := cmd.Flags().GetString("zone")
if err != nil {
return iamFlags{}, fmt.Errorf("parsing zone string: %w", err)
}
if !config.ValidateAWSZone(zone) {
return iamFlags{}, fmt.Errorf("invalid AWS zone. To find a valid zone, please refer to our docs and https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/using-regions-availability-zones.html#concepts-availability-zones")
}
// Infer region from zone.
region := zone[:len(zone)-1]
if !config.ValidateAWSRegion(region) {
return iamFlags{}, fmt.Errorf("invalid AWS region: %s", region)
}
flags.aws = awsFlags{
prefix: prefix,
zone: zone,
region: region,
}
// Setup IAM config.
iamConfig.AWS = cloudcmd.AWSIAMConfig{
Region: flags.aws.region,
Prefix: flags.aws.prefix,
}
return flags, nil
}
func (c *awsIAMCreator) printConfirmValues(cmd *cobra.Command, flags iamFlags) {
cmd.Printf("Region:\t\t%s\n", flags.aws.region)
cmd.Printf("Name Prefix:\t%s\n\n", flags.aws.prefix)
}
func (c *awsIAMCreator) printOutputValues(cmd *cobra.Command, flags iamFlags, iamFile cloudcmd.IAMOutput) {
cmd.Printf("region:\t\t\t%s\n", flags.aws.region)
cmd.Printf("zone:\t\t\t%s\n", flags.aws.zone)
cmd.Printf("iamProfileControlPlane:\t%s\n", iamFile.AWSOutput.ControlPlaneInstanceProfile)
cmd.Printf("iamProfileWorkerNodes:\t%s\n\n", iamFile.AWSOutput.WorkerNodeInstanceProfile)
}
func (c *awsIAMCreator) writeOutputValuesToConfig(conf *config.Config, flags iamFlags, iamFile cloudcmd.IAMOutput) {
conf.Provider.AWS.Region = flags.aws.region
conf.Provider.AWS.Zone = flags.aws.zone
conf.Provider.AWS.IAMProfileControlPlane = iamFile.AWSOutput.ControlPlaneInstanceProfile
conf.Provider.AWS.IAMProfileWorkerNodes = iamFile.AWSOutput.WorkerNodeInstanceProfile
for groupName, group := range conf.NodeGroups {
group.Zone = flags.aws.zone
conf.NodeGroups[groupName] = group
}
}
func (c *awsIAMCreator) parseAndWriteIDFile(_ cloudcmd.IAMOutput, _ file.Handler) error {
return nil
}
// azureIAMCreator implements the providerIAMCreator interface for Azure.
type azureIAMCreator struct{}
func (c *azureIAMCreator) parseFlagsAndSetupConfig(cmd *cobra.Command, flags iamFlags, iamConfig *cloudcmd.IAMConfigOptions) (iamFlags, error) {
region, err := cmd.Flags().GetString("region")
if err != nil {
return iamFlags{}, fmt.Errorf("parsing region string: %w", err)
}
resourceGroup, err := cmd.Flags().GetString("resourceGroup")
if err != nil {
return iamFlags{}, fmt.Errorf("parsing resourceGroup string: %w", err)
}
servicePrincipal, err := cmd.Flags().GetString("servicePrincipal")
if err != nil {
return iamFlags{}, fmt.Errorf("parsing servicePrincipal string: %w", err)
}
flags.azure = azureFlags{
region: region,
resourceGroup: resourceGroup,
servicePrincipal: servicePrincipal,
}
// Setup IAM config.
iamConfig.Azure = cloudcmd.AzureIAMConfig{
Region: flags.azure.region,
ResourceGroup: flags.azure.resourceGroup,
ServicePrincipal: flags.azure.servicePrincipal,
}
return flags, nil
}
func (c *azureIAMCreator) printConfirmValues(cmd *cobra.Command, flags iamFlags) {
cmd.Printf("Region:\t\t\t%s\n", flags.azure.region)
cmd.Printf("Resource Group:\t\t%s\n", flags.azure.resourceGroup)
cmd.Printf("Service Principal:\t%s\n\n", flags.azure.servicePrincipal)
}
func (c *azureIAMCreator) printOutputValues(cmd *cobra.Command, flags iamFlags, iamFile cloudcmd.IAMOutput) {
cmd.Printf("subscription:\t\t%s\n", iamFile.AzureOutput.SubscriptionID)
cmd.Printf("tenant:\t\t\t%s\n", iamFile.AzureOutput.TenantID)
cmd.Printf("location:\t\t%s\n", flags.azure.region)
cmd.Printf("resourceGroup:\t\t%s\n", flags.azure.resourceGroup)
cmd.Printf("userAssignedIdentity:\t%s\n", iamFile.AzureOutput.UAMIID)
}
func (c *azureIAMCreator) writeOutputValuesToConfig(conf *config.Config, flags iamFlags, iamFile cloudcmd.IAMOutput) {
conf.Provider.Azure.SubscriptionID = iamFile.AzureOutput.SubscriptionID
conf.Provider.Azure.TenantID = iamFile.AzureOutput.TenantID
conf.Provider.Azure.Location = flags.azure.region
conf.Provider.Azure.ResourceGroup = flags.azure.resourceGroup
conf.Provider.Azure.UserAssignedIdentity = iamFile.AzureOutput.UAMIID
}
func (c *azureIAMCreator) parseAndWriteIDFile(_ cloudcmd.IAMOutput, _ file.Handler) error {
return nil
}
// gcpIAMCreator implements the providerIAMCreator interface for GCP.
type gcpIAMCreator struct {
pf pathprefix.PathPrefixer
}
func (c *gcpIAMCreator) parseFlagsAndSetupConfig(cmd *cobra.Command, flags iamFlags, iamConfig *cloudcmd.IAMConfigOptions) (iamFlags, error) {
zone, err := cmd.Flags().GetString("zone")
if err != nil {
return iamFlags{}, fmt.Errorf("parsing zone string: %w", err)
}
if !zoneRegex.MatchString(zone) {
return iamFlags{}, fmt.Errorf("invalid zone string: %s", zone)
}
// Infer region from zone.
zoneParts := strings.Split(zone, "-")
region := fmt.Sprintf("%s-%s", zoneParts[0], zoneParts[1])
if !regionRegex.MatchString(region) {
return iamFlags{}, fmt.Errorf("invalid region string: %s", region)
}
projectID, err := cmd.Flags().GetString("projectID")
if err != nil {
return iamFlags{}, fmt.Errorf("parsing projectID string: %w", err)
}
if !gcpIDRegex.MatchString(projectID) {
return iamFlags{}, fmt.Errorf("projectID %q doesn't match %s", projectID, gcpIDRegex)
}
serviceAccID, err := cmd.Flags().GetString("serviceAccountID")
if err != nil {
return iamFlags{}, fmt.Errorf("parsing serviceAccountID string: %w", err)
}
if !gcpIDRegex.MatchString(serviceAccID) {
return iamFlags{}, fmt.Errorf("serviceAccountID %q doesn't match %s", serviceAccID, gcpIDRegex)
}
flags.gcp = gcpFlags{
zone: zone,
region: region,
projectID: projectID,
serviceAccountID: serviceAccID,
}
// Setup IAM config.
iamConfig.GCP = cloudcmd.GCPIAMConfig{
Zone: flags.gcp.zone,
Region: flags.gcp.region,
ProjectID: flags.gcp.projectID,
ServiceAccountID: flags.gcp.serviceAccountID,
}
return flags, nil
}
func (c *gcpIAMCreator) printConfirmValues(cmd *cobra.Command, flags iamFlags) {
cmd.Printf("Project ID:\t\t%s\n", flags.gcp.projectID)
cmd.Printf("Service Account ID:\t%s\n", flags.gcp.serviceAccountID)
cmd.Printf("Region:\t\t\t%s\n", flags.gcp.region)
cmd.Printf("Zone:\t\t\t%s\n\n", flags.gcp.zone)
}
func (c *gcpIAMCreator) printOutputValues(cmd *cobra.Command, flags iamFlags, _ cloudcmd.IAMOutput) {
cmd.Printf("projectID:\t\t%s\n", flags.gcp.projectID)
cmd.Printf("region:\t\t\t%s\n", flags.gcp.region)
cmd.Printf("zone:\t\t\t%s\n", flags.gcp.zone)
cmd.Printf("serviceAccountKeyPath:\t%s\n\n", c.pf.PrefixPrintablePath(constants.GCPServiceAccountKeyFilename))
}
func (c *gcpIAMCreator) writeOutputValuesToConfig(conf *config.Config, flags iamFlags, _ cloudcmd.IAMOutput) {
conf.Provider.GCP.Project = flags.gcp.projectID
conf.Provider.GCP.ServiceAccountKeyPath = constants.GCPServiceAccountKeyFilename // File was created in workspace, so only the filename is needed.
conf.Provider.GCP.Region = flags.gcp.region
conf.Provider.GCP.Zone = flags.gcp.zone
for groupName, group := range conf.NodeGroups {
group.Zone = flags.gcp.zone
conf.NodeGroups[groupName] = group
}
}
func (c *gcpIAMCreator) parseAndWriteIDFile(iamFile cloudcmd.IAMOutput, fileHandler file.Handler) error {
// GCP needs to write the service account key to a file.
tmpOut, err := parseIDFile(iamFile.GCPOutput.ServiceAccountKey)
if err != nil {
return err
}
return fileHandler.WriteJSON(constants.GCPServiceAccountKeyFilename, tmpOut, file.OptNone)
validateConfigWithFlagCompatibility(config.Config) error
}
// parseIDFile parses the given base64 encoded JSON string of the GCP service account key and returns a map.
@ -594,30 +219,17 @@ func parseIDFile(serviceAccountKeyBase64 string) (map[string]string, error) {
}
// validateConfigWithFlagCompatibility checks if the config is compatible with the flags.
func validateConfigWithFlagCompatibility(iamProvider cloudprovider.Provider, cfg config.Config, flags iamFlags) error {
func validateConfigWithFlagCompatibility(iamProvider cloudprovider.Provider, cfg config.Config, zone string) error {
if !cfg.HasProvider(iamProvider) {
return fmt.Errorf("cloud provider from the the configuration file differs from the one provided via the command %q", iamProvider)
}
return checkIfCfgZoneAndFlagZoneDiffer(iamProvider, flags, cfg)
return checkIfCfgZoneAndFlagZoneDiffer(zone, cfg)
}
func checkIfCfgZoneAndFlagZoneDiffer(iamProvider cloudprovider.Provider, flags iamFlags, cfg config.Config) error {
flagZone := flagZoneOrAzRegion(iamProvider, flags)
func checkIfCfgZoneAndFlagZoneDiffer(zone string, cfg config.Config) error {
configZone := cfg.GetZone()
if configZone != "" && flagZone != configZone {
return fmt.Errorf("zone/region from the configuration file %q differs from the one provided via flags %q", configZone, flagZone)
if configZone != "" && zone != configZone {
return fmt.Errorf("zone/region from the configuration file %q differs from the one provided via flags %q", configZone, zone)
}
return nil
}
func flagZoneOrAzRegion(provider cloudprovider.Provider, flags iamFlags) string {
switch provider {
case cloudprovider.AWS:
return flags.aws.zone
case cloudprovider.Azure:
return flags.azure.region
case cloudprovider.GCP:
return flags.gcp.zone
}
return ""
}

View File

@ -82,7 +82,6 @@ func TestIAMCreateAWS(t *testing.T) {
testCases := map[string]struct {
setupFs func(require *require.Assertions, provider cloudprovider.Provider, existingConfigFiles []string, existingDirs []string) afero.Fs
creator *stubIAMCreator
provider cloudprovider.Provider
zoneFlag string
prefixFlag string
yesFlag bool
@ -96,26 +95,14 @@ func TestIAMCreateAWS(t *testing.T) {
"iam create aws": {
setupFs: defaultFs,
creator: &stubIAMCreator{id: validIAMIDFile},
provider: cloudprovider.AWS,
zoneFlag: "us-east-2a",
prefixFlag: "test",
yesFlag: true,
existingConfigFiles: []string{constants.ConfigFilename},
},
"iam create aws fails when --zone has no availability zone": {
setupFs: defaultFs,
creator: &stubIAMCreator{id: validIAMIDFile},
provider: cloudprovider.AWS,
zoneFlag: "us-east-1",
prefixFlag: "test",
yesFlag: true,
existingConfigFiles: []string{constants.ConfigFilename},
wantErr: true,
},
"iam create aws --update-config": {
setupFs: defaultFs,
creator: &stubIAMCreator{id: validIAMIDFile},
provider: cloudprovider.AWS,
zoneFlag: "us-east-2a",
prefixFlag: "test",
yesFlag: true,
@ -130,7 +117,6 @@ func TestIAMCreateAWS(t *testing.T) {
return *cfg
}()),
creator: &stubIAMCreator{id: validIAMIDFile},
provider: cloudprovider.AWS,
zoneFlag: "us-east-1a",
prefixFlag: "test",
yesFlag: true,
@ -141,7 +127,6 @@ func TestIAMCreateAWS(t *testing.T) {
"iam create aws --update-config fails when config has different provider": {
setupFs: createFSWithConfig(*createConfig(cloudprovider.GCP)),
creator: &stubIAMCreator{id: validIAMIDFile},
provider: cloudprovider.AWS,
zoneFlag: "us-east-1a",
prefixFlag: "test",
yesFlag: true,
@ -152,7 +137,6 @@ func TestIAMCreateAWS(t *testing.T) {
"iam create aws no config": {
setupFs: defaultFs,
creator: &stubIAMCreator{id: validIAMIDFile},
provider: cloudprovider.AWS,
zoneFlag: "us-east-2a",
prefixFlag: "test",
yesFlag: true,
@ -160,7 +144,6 @@ func TestIAMCreateAWS(t *testing.T) {
"iam create aws existing terraform dir": {
setupFs: defaultFs,
creator: &stubIAMCreator{id: validIAMIDFile},
provider: cloudprovider.AWS,
zoneFlag: "us-east-2a",
prefixFlag: "test",
yesFlag: true,
@ -170,7 +153,6 @@ func TestIAMCreateAWS(t *testing.T) {
"interactive": {
setupFs: defaultFs,
creator: &stubIAMCreator{id: validIAMIDFile},
provider: cloudprovider.AWS,
zoneFlag: "us-east-2a",
prefixFlag: "test",
stdin: "yes\n",
@ -178,7 +160,6 @@ func TestIAMCreateAWS(t *testing.T) {
"interactive update config": {
setupFs: defaultFs,
creator: &stubIAMCreator{id: validIAMIDFile},
provider: cloudprovider.AWS,
zoneFlag: "us-east-2a",
prefixFlag: "test",
stdin: "yes\n",
@ -188,7 +169,6 @@ func TestIAMCreateAWS(t *testing.T) {
"interactive abort": {
setupFs: defaultFs,
creator: &stubIAMCreator{id: validIAMIDFile},
provider: cloudprovider.AWS,
zoneFlag: "us-east-2a",
prefixFlag: "test",
stdin: "no\n",
@ -197,7 +177,6 @@ func TestIAMCreateAWS(t *testing.T) {
"interactive update config abort": {
setupFs: defaultFs,
creator: &stubIAMCreator{id: validIAMIDFile},
provider: cloudprovider.AWS,
zoneFlag: "us-east-2a",
prefixFlag: "test",
stdin: "no\n",
@ -205,19 +184,9 @@ func TestIAMCreateAWS(t *testing.T) {
wantAbort: true,
existingConfigFiles: []string{constants.ConfigFilename},
},
"invalid zone": {
setupFs: defaultFs,
creator: &stubIAMCreator{id: validIAMIDFile},
provider: cloudprovider.AWS,
zoneFlag: "us-west",
prefixFlag: "test",
yesFlag: true,
wantErr: true,
},
"unwritable fs": {
setupFs: readOnlyFs,
creator: &stubIAMCreator{id: validIAMIDFile},
provider: cloudprovider.AWS,
zoneFlag: "us-east-2a",
prefixFlag: "test",
yesFlag: true,
@ -236,37 +205,26 @@ func TestIAMCreateAWS(t *testing.T) {
cmd.SetErr(&bytes.Buffer{})
cmd.SetIn(bytes.NewBufferString(tc.stdin))
// register persistent flags manually
cmd.Flags().String("workspace", "", "")
cmd.Flags().Bool("update-config", false, "")
cmd.Flags().Bool("yes", false, "")
cmd.Flags().String("name", "constell", "")
cmd.Flags().String("tf-log", "NONE", "")
if tc.zoneFlag != "" {
require.NoError(cmd.Flags().Set("zone", tc.zoneFlag))
}
if tc.prefixFlag != "" {
require.NoError(cmd.Flags().Set("prefix", tc.prefixFlag))
}
if tc.yesFlag {
require.NoError(cmd.Flags().Set("yes", "true"))
}
if tc.updateConfigFlag {
require.NoError(cmd.Flags().Set("update-config", "true"))
}
fileHandler := file.NewHandler(tc.setupFs(require, tc.provider, tc.existingConfigFiles, tc.existingDirs))
fileHandler := file.NewHandler(tc.setupFs(require, cloudprovider.AWS, tc.existingConfigFiles, tc.existingDirs))
iamCreator := &iamCreator{
cmd: cmd,
log: logger.NewTest(t),
spinner: &nopSpinner{},
creator: tc.creator,
fileHandler: fileHandler,
iamConfig: &cloudcmd.IAMConfigOptions{},
provider: tc.provider,
providerCreator: &awsIAMCreator{},
cmd: cmd,
log: logger.NewTest(t),
spinner: &nopSpinner{},
creator: tc.creator,
fileHandler: fileHandler,
iamConfig: &cloudcmd.IAMConfigOptions{},
provider: cloudprovider.AWS,
flags: iamCreateFlags{
yes: tc.yesFlag,
updateConfig: tc.updateConfigFlag,
},
providerCreator: &awsIAMCreator{
flags: awsIAMCreateFlags{
zone: tc.zoneFlag,
prefix: tc.prefixFlag,
},
},
}
err := iamCreator.create(cmd.Context())
@ -315,7 +273,6 @@ func TestIAMCreateAzure(t *testing.T) {
testCases := map[string]struct {
setupFs func(require *require.Assertions, provider cloudprovider.Provider, existingConfigFiles []string, existingDirs []string) afero.Fs
creator *stubIAMCreator
provider cloudprovider.Provider
regionFlag string
servicePrincipalFlag string
resourceGroupFlag string
@ -330,7 +287,6 @@ func TestIAMCreateAzure(t *testing.T) {
"iam create azure": {
setupFs: defaultFs,
creator: &stubIAMCreator{id: validIAMIDFile},
provider: cloudprovider.Azure,
regionFlag: "westus",
servicePrincipalFlag: "constell-test-sp",
resourceGroupFlag: "constell-test-rg",
@ -339,7 +295,6 @@ func TestIAMCreateAzure(t *testing.T) {
"iam create azure with existing config": {
setupFs: defaultFs,
creator: &stubIAMCreator{id: validIAMIDFile},
provider: cloudprovider.Azure,
regionFlag: "westus",
servicePrincipalFlag: "constell-test-sp",
resourceGroupFlag: "constell-test-rg",
@ -349,7 +304,6 @@ func TestIAMCreateAzure(t *testing.T) {
"iam create azure --update-config": {
setupFs: defaultFs,
creator: &stubIAMCreator{id: validIAMIDFile},
provider: cloudprovider.Azure,
regionFlag: "westus",
servicePrincipalFlag: "constell-test-sp",
resourceGroupFlag: "constell-test-rg",
@ -360,7 +314,6 @@ func TestIAMCreateAzure(t *testing.T) {
"iam create azure existing terraform dir": {
setupFs: defaultFs,
creator: &stubIAMCreator{id: validIAMIDFile},
provider: cloudprovider.Azure,
regionFlag: "westus",
servicePrincipalFlag: "constell-test-sp",
resourceGroupFlag: "constell-test-rg",
@ -371,7 +324,6 @@ func TestIAMCreateAzure(t *testing.T) {
"interactive": {
setupFs: defaultFs,
creator: &stubIAMCreator{id: validIAMIDFile},
provider: cloudprovider.Azure,
regionFlag: "westus",
servicePrincipalFlag: "constell-test-sp",
resourceGroupFlag: "constell-test-rg",
@ -380,7 +332,6 @@ func TestIAMCreateAzure(t *testing.T) {
"interactive update config": {
setupFs: defaultFs,
creator: &stubIAMCreator{id: validIAMIDFile},
provider: cloudprovider.Azure,
regionFlag: "westus",
servicePrincipalFlag: "constell-test-sp",
resourceGroupFlag: "constell-test-rg",
@ -391,7 +342,6 @@ func TestIAMCreateAzure(t *testing.T) {
"interactive abort": {
setupFs: defaultFs,
creator: &stubIAMCreator{id: validIAMIDFile},
provider: cloudprovider.Azure,
regionFlag: "westus",
servicePrincipalFlag: "constell-test-sp",
resourceGroupFlag: "constell-test-rg",
@ -401,7 +351,6 @@ func TestIAMCreateAzure(t *testing.T) {
"interactive update config abort": {
setupFs: defaultFs,
creator: &stubIAMCreator{id: validIAMIDFile},
provider: cloudprovider.Azure,
regionFlag: "westus",
servicePrincipalFlag: "constell-test-sp",
resourceGroupFlag: "constell-test-rg",
@ -413,7 +362,6 @@ func TestIAMCreateAzure(t *testing.T) {
"unwritable fs": {
setupFs: readOnlyFs,
creator: &stubIAMCreator{id: validIAMIDFile},
provider: cloudprovider.Azure,
regionFlag: "westus",
servicePrincipalFlag: "constell-test-sp",
resourceGroupFlag: "constell-test-rg",
@ -433,40 +381,27 @@ func TestIAMCreateAzure(t *testing.T) {
cmd.SetErr(&bytes.Buffer{})
cmd.SetIn(bytes.NewBufferString(tc.stdin))
// register persistent flags manually
cmd.Flags().String("workspace", "", "")
cmd.Flags().Bool("update-config", false, "")
cmd.Flags().Bool("yes", false, "")
cmd.Flags().String("name", "constell", "")
cmd.Flags().String("tf-log", "NONE", "")
if tc.regionFlag != "" {
require.NoError(cmd.Flags().Set("region", tc.regionFlag))
}
if tc.resourceGroupFlag != "" {
require.NoError(cmd.Flags().Set("resourceGroup", tc.resourceGroupFlag))
}
if tc.servicePrincipalFlag != "" {
require.NoError(cmd.Flags().Set("servicePrincipal", tc.servicePrincipalFlag))
}
if tc.yesFlag {
require.NoError(cmd.Flags().Set("yes", "true"))
}
if tc.updateConfigFlag {
require.NoError(cmd.Flags().Set("update-config", "true"))
}
fileHandler := file.NewHandler(tc.setupFs(require, tc.provider, tc.existingConfigFiles, tc.existingDirs))
fileHandler := file.NewHandler(tc.setupFs(require, cloudprovider.Azure, tc.existingConfigFiles, tc.existingDirs))
iamCreator := &iamCreator{
cmd: cmd,
log: logger.NewTest(t),
spinner: &nopSpinner{},
creator: tc.creator,
fileHandler: fileHandler,
iamConfig: &cloudcmd.IAMConfigOptions{},
provider: tc.provider,
providerCreator: &azureIAMCreator{},
cmd: cmd,
log: logger.NewTest(t),
spinner: &nopSpinner{},
creator: tc.creator,
fileHandler: fileHandler,
iamConfig: &cloudcmd.IAMConfigOptions{},
provider: cloudprovider.Azure,
flags: iamCreateFlags{
yes: tc.yesFlag,
updateConfig: tc.updateConfigFlag,
},
providerCreator: &azureIAMCreator{
flags: azureIAMCreateFlags{
region: tc.regionFlag,
resourceGroup: tc.resourceGroupFlag,
servicePrincipal: tc.servicePrincipalFlag,
},
},
}
err := iamCreator.create(cmd.Context())
@ -519,7 +454,6 @@ func TestIAMCreateGCP(t *testing.T) {
testCases := map[string]struct {
setupFs func(require *require.Assertions, provider cloudprovider.Provider, existingConfigFiles []string, existingDirs []string) afero.Fs
creator *stubIAMCreator
provider cloudprovider.Provider
zoneFlag string
serviceAccountIDFlag string
projectIDFlag string
@ -534,7 +468,6 @@ func TestIAMCreateGCP(t *testing.T) {
"iam create gcp": {
setupFs: defaultFs,
creator: &stubIAMCreator{id: validIAMIDFile},
provider: cloudprovider.GCP,
zoneFlag: "europe-west1-a",
serviceAccountIDFlag: "constell-test",
projectIDFlag: "constell-1234",
@ -543,7 +476,6 @@ func TestIAMCreateGCP(t *testing.T) {
"iam create gcp with existing config": {
setupFs: defaultFs,
creator: &stubIAMCreator{id: validIAMIDFile},
provider: cloudprovider.GCP,
zoneFlag: "europe-west1-a",
serviceAccountIDFlag: "constell-test",
projectIDFlag: "constell-1234",
@ -553,7 +485,6 @@ func TestIAMCreateGCP(t *testing.T) {
"iam create gcp --update-config": {
setupFs: defaultFs,
creator: &stubIAMCreator{id: validIAMIDFile},
provider: cloudprovider.GCP,
zoneFlag: "europe-west1-a",
serviceAccountIDFlag: "constell-test",
projectIDFlag: "constell-1234",
@ -564,7 +495,6 @@ func TestIAMCreateGCP(t *testing.T) {
"iam create gcp existing terraform dir": {
setupFs: defaultFs,
creator: &stubIAMCreator{id: validIAMIDFile},
provider: cloudprovider.GCP,
zoneFlag: "europe-west1-a",
serviceAccountIDFlag: "constell-test",
projectIDFlag: "constell-1234",
@ -573,18 +503,9 @@ func TestIAMCreateGCP(t *testing.T) {
yesFlag: true,
wantErr: true,
},
"iam create gcp invalid flags": {
setupFs: defaultFs,
creator: &stubIAMCreator{id: validIAMIDFile},
provider: cloudprovider.GCP,
zoneFlag: "-a",
yesFlag: true,
wantErr: true,
},
"iam create gcp invalid b64": {
setupFs: defaultFs,
creator: &stubIAMCreator{id: invalidIAMIDFile},
provider: cloudprovider.GCP,
zoneFlag: "europe-west1-a",
serviceAccountIDFlag: "constell-test",
projectIDFlag: "constell-1234",
@ -594,7 +515,6 @@ func TestIAMCreateGCP(t *testing.T) {
"interactive": {
setupFs: defaultFs,
creator: &stubIAMCreator{id: validIAMIDFile},
provider: cloudprovider.GCP,
zoneFlag: "europe-west1-a",
serviceAccountIDFlag: "constell-test",
projectIDFlag: "constell-1234",
@ -603,7 +523,6 @@ func TestIAMCreateGCP(t *testing.T) {
"interactive update config": {
setupFs: defaultFs,
creator: &stubIAMCreator{id: validIAMIDFile},
provider: cloudprovider.GCP,
zoneFlag: "europe-west1-a",
serviceAccountIDFlag: "constell-test",
projectIDFlag: "constell-1234",
@ -614,7 +533,6 @@ func TestIAMCreateGCP(t *testing.T) {
"interactive abort": {
setupFs: defaultFs,
creator: &stubIAMCreator{id: validIAMIDFile},
provider: cloudprovider.GCP,
zoneFlag: "europe-west1-a",
serviceAccountIDFlag: "constell-test",
projectIDFlag: "constell-1234",
@ -624,7 +542,6 @@ func TestIAMCreateGCP(t *testing.T) {
"interactive abort update config": {
setupFs: defaultFs,
creator: &stubIAMCreator{id: validIAMIDFile},
provider: cloudprovider.GCP,
zoneFlag: "europe-west1-a",
serviceAccountIDFlag: "constell-test",
projectIDFlag: "constell-1234",
@ -636,7 +553,6 @@ func TestIAMCreateGCP(t *testing.T) {
"unwritable fs": {
setupFs: readOnlyFs,
creator: &stubIAMCreator{id: validIAMIDFile},
provider: cloudprovider.GCP,
zoneFlag: "europe-west1-a",
serviceAccountIDFlag: "constell-test",
projectIDFlag: "constell-1234",
@ -656,40 +572,27 @@ func TestIAMCreateGCP(t *testing.T) {
cmd.SetErr(&bytes.Buffer{})
cmd.SetIn(bytes.NewBufferString(tc.stdin))
// register persistent flags manually
cmd.Flags().String("workspace", "", "")
cmd.Flags().Bool("update-config", false, "")
cmd.Flags().Bool("yes", false, "")
cmd.Flags().String("name", "constell", "")
cmd.Flags().String("tf-log", "NONE", "")
if tc.zoneFlag != "" {
require.NoError(cmd.Flags().Set("zone", tc.zoneFlag))
}
if tc.serviceAccountIDFlag != "" {
require.NoError(cmd.Flags().Set("serviceAccountID", tc.serviceAccountIDFlag))
}
if tc.projectIDFlag != "" {
require.NoError(cmd.Flags().Set("projectID", tc.projectIDFlag))
}
if tc.yesFlag {
require.NoError(cmd.Flags().Set("yes", "true"))
}
if tc.updateConfigFlag {
require.NoError(cmd.Flags().Set("update-config", "true"))
}
fileHandler := file.NewHandler(tc.setupFs(require, tc.provider, tc.existingConfigFiles, tc.existingDirs))
fileHandler := file.NewHandler(tc.setupFs(require, cloudprovider.GCP, tc.existingConfigFiles, tc.existingDirs))
iamCreator := &iamCreator{
cmd: cmd,
log: logger.NewTest(t),
spinner: &nopSpinner{},
creator: tc.creator,
fileHandler: fileHandler,
iamConfig: &cloudcmd.IAMConfigOptions{},
provider: tc.provider,
providerCreator: &gcpIAMCreator{},
cmd: cmd,
log: logger.NewTest(t),
spinner: &nopSpinner{},
creator: tc.creator,
fileHandler: fileHandler,
iamConfig: &cloudcmd.IAMConfigOptions{},
provider: cloudprovider.GCP,
flags: iamCreateFlags{
yes: tc.yesFlag,
updateConfig: tc.updateConfigFlag,
},
providerCreator: &gcpIAMCreator{
flags: gcpIAMCreateFlags{
zone: tc.zoneFlag,
serviceAccountID: tc.serviceAccountIDFlag,
projectID: tc.projectIDFlag,
},
},
}
err := iamCreator.create(cmd.Context())
@ -724,7 +627,7 @@ func TestValidateConfigWithFlagCompatibility(t *testing.T) {
testCases := map[string]struct {
iamProvider cloudprovider.Provider
cfg config.Config
flags iamFlags
zone string
wantErr bool
}{
"AWS valid when cfg.zone == flag.zone": {
@ -734,20 +637,12 @@ func TestValidateConfigWithFlagCompatibility(t *testing.T) {
cfg.Provider.AWS.Zone = "europe-west-1a"
return *cfg
}(),
flags: iamFlags{
aws: awsFlags{
zone: "europe-west-1a",
},
},
zone: "europe-west-1a",
},
"AWS valid when cfg.zone not set": {
iamProvider: cloudprovider.AWS,
cfg: *createConfig(cloudprovider.AWS),
flags: iamFlags{
aws: awsFlags{
zone: "europe-west-1a",
},
},
zone: "europe-west-1a",
},
"GCP invalid when cfg.zone != flag.zone": {
iamProvider: cloudprovider.GCP,
@ -756,11 +651,7 @@ func TestValidateConfigWithFlagCompatibility(t *testing.T) {
cfg.Provider.GCP.Zone = "europe-west-1a"
return *cfg
}(),
flags: iamFlags{
aws: awsFlags{
zone: "us-west-1a",
},
},
zone: "us-west-1a",
wantErr: true,
},
"Azure invalid when cfg.zone != flag.zone": {
@ -770,11 +661,7 @@ func TestValidateConfigWithFlagCompatibility(t *testing.T) {
cfg.Provider.Azure.Location = "europe-west-1a"
return *cfg
}(),
flags: iamFlags{
aws: awsFlags{
zone: "us-west-1a",
},
},
zone: "us-west-1a",
wantErr: true,
},
"GCP invalid when cfg.provider different from iam provider": {
@ -786,7 +673,7 @@ func TestValidateConfigWithFlagCompatibility(t *testing.T) {
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
err := validateConfigWithFlagCompatibility(tc.iamProvider, tc.cfg, tc.flags)
err := validateConfigWithFlagCompatibility(tc.iamProvider, tc.cfg, tc.zone)
if tc.wantErr {
assert.Error(err)
return

View File

@ -0,0 +1,121 @@
/*
Copyright (c) Edgeless Systems GmbH
SPDX-License-Identifier: AGPL-3.0-only
*/
package cmd
import (
"errors"
"fmt"
"github.com/edgelesssys/constellation/v2/cli/internal/cloudcmd"
"github.com/edgelesssys/constellation/v2/internal/cloud/cloudprovider"
"github.com/edgelesssys/constellation/v2/internal/config"
"github.com/edgelesssys/constellation/v2/internal/file"
"github.com/spf13/cobra"
"github.com/spf13/pflag"
)
// newIAMCreateAWSCmd returns a new cobra.Command for the iam create aws command.
func newIAMCreateAWSCmd() *cobra.Command {
cmd := &cobra.Command{
Use: "aws",
Short: "Create IAM configuration on AWS for your Constellation cluster",
Long: "Create IAM configuration on AWS for your Constellation cluster.",
Args: cobra.ExactArgs(0),
RunE: runIAMCreateAWS,
}
cmd.Flags().String("prefix", "", "name prefix for all resources (required)")
must(cobra.MarkFlagRequired(cmd.Flags(), "prefix"))
cmd.Flags().String("zone", "", "AWS availability zone the resources will be created in, e.g., us-east-2a (required)\n"+
"See the Constellation docs for a list of currently supported regions.")
must(cobra.MarkFlagRequired(cmd.Flags(), "zone"))
return cmd
}
func runIAMCreateAWS(cmd *cobra.Command, _ []string) error {
creator := &awsIAMCreator{}
if err := creator.flags.parse(cmd.Flags()); err != nil {
return err
}
return runIAMCreate(cmd, creator, cloudprovider.AWS)
}
// awsIAMCreateFlags contains the parsed flags of the iam create aws command.
type awsIAMCreateFlags struct {
prefix string
region string
zone string
}
func (f *awsIAMCreateFlags) parse(flags *pflag.FlagSet) error {
var err error
f.prefix, err = flags.GetString("prefix")
if err != nil {
return fmt.Errorf("getting 'prefix' flag: %w", err)
}
if len(f.prefix) > 36 {
return errors.New("prefix must be 36 characters or less")
}
f.zone, err = flags.GetString("zone")
if err != nil {
return fmt.Errorf("getting 'zone' flag: %w", err)
}
if !config.ValidateAWSZone(f.zone) {
return errors.New("invalid AWS zone. To find a valid zone, please refer to our docs and https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/using-regions-availability-zones.html#concepts-availability-zones")
}
// Infer region from zone.
f.region = f.zone[:len(f.zone)-1]
if !config.ValidateAWSRegion(f.region) {
return fmt.Errorf("invalid AWS region: %s", f.region)
}
return nil
}
// awsIAMCreator implements the providerIAMCreator interface for AWS.
type awsIAMCreator struct {
flags awsIAMCreateFlags
}
func (c *awsIAMCreator) getIAMConfigOptions() *cloudcmd.IAMConfigOptions {
return &cloudcmd.IAMConfigOptions{
AWS: cloudcmd.AWSIAMConfig{
Region: c.flags.region,
Prefix: c.flags.prefix,
},
}
}
func (c *awsIAMCreator) printConfirmValues(cmd *cobra.Command) {
cmd.Printf("Region:\t\t%s\n", c.flags.region)
cmd.Printf("Name Prefix:\t%s\n\n", c.flags.prefix)
}
func (c *awsIAMCreator) printOutputValues(cmd *cobra.Command, iamFile cloudcmd.IAMOutput) {
cmd.Printf("region:\t\t\t%s\n", c.flags.region)
cmd.Printf("zone:\t\t\t%s\n", c.flags.zone)
cmd.Printf("iamProfileControlPlane:\t%s\n", iamFile.AWSOutput.ControlPlaneInstanceProfile)
cmd.Printf("iamProfileWorkerNodes:\t%s\n\n", iamFile.AWSOutput.WorkerNodeInstanceProfile)
}
func (c *awsIAMCreator) writeOutputValuesToConfig(conf *config.Config, iamFile cloudcmd.IAMOutput) {
conf.Provider.AWS.Region = c.flags.region
conf.Provider.AWS.Zone = c.flags.zone
conf.Provider.AWS.IAMProfileControlPlane = iamFile.AWSOutput.ControlPlaneInstanceProfile
conf.Provider.AWS.IAMProfileWorkerNodes = iamFile.AWSOutput.WorkerNodeInstanceProfile
for groupName, group := range conf.NodeGroups {
group.Zone = c.flags.zone
conf.NodeGroups[groupName] = group
}
}
func (c *awsIAMCreator) parseAndWriteIDFile(_ cloudcmd.IAMOutput, _ file.Handler) error {
return nil
}
func (c *awsIAMCreator) validateConfigWithFlagCompatibility(conf config.Config) error {
return validateConfigWithFlagCompatibility(cloudprovider.AWS, conf, c.flags.zone)
}

View File

@ -0,0 +1,113 @@
/*
Copyright (c) Edgeless Systems GmbH
SPDX-License-Identifier: AGPL-3.0-only
*/
package cmd
import (
"fmt"
"github.com/edgelesssys/constellation/v2/cli/internal/cloudcmd"
"github.com/edgelesssys/constellation/v2/internal/cloud/cloudprovider"
"github.com/edgelesssys/constellation/v2/internal/config"
"github.com/edgelesssys/constellation/v2/internal/file"
"github.com/spf13/cobra"
"github.com/spf13/pflag"
)
// newIAMCreateAzureCmd returns a new cobra.Command for the iam create azure command.
func newIAMCreateAzureCmd() *cobra.Command {
cmd := &cobra.Command{
Use: "azure",
Short: "Create IAM configuration on Microsoft Azure for your Constellation cluster",
Long: "Create IAM configuration on Microsoft Azure for your Constellation cluster.",
Args: cobra.ExactArgs(0),
RunE: runIAMCreateAzure,
}
cmd.Flags().String("resourceGroup", "", "name prefix of the two resource groups your cluster / IAM resources will be created in (required)")
must(cobra.MarkFlagRequired(cmd.Flags(), "resourceGroup"))
cmd.Flags().String("region", "", "region the resources will be created in, e.g., westus (required)")
must(cobra.MarkFlagRequired(cmd.Flags(), "region"))
cmd.Flags().String("servicePrincipal", "", "name of the service principal that will be created (required)")
must(cobra.MarkFlagRequired(cmd.Flags(), "servicePrincipal"))
return cmd
}
func runIAMCreateAzure(cmd *cobra.Command, _ []string) error {
creator := &azureIAMCreator{}
if err := creator.flags.parse(cmd.Flags()); err != nil {
return err
}
return runIAMCreate(cmd, creator, cloudprovider.Azure)
}
// azureIAMCreateFlags contains the parsed flags of the iam create azure command.
type azureIAMCreateFlags struct {
region string
resourceGroup string
servicePrincipal string
}
func (f *azureIAMCreateFlags) parse(flags *pflag.FlagSet) error {
var err error
f.region, err = flags.GetString("region")
if err != nil {
return fmt.Errorf("getting 'region' flag: %w", err)
}
f.resourceGroup, err = flags.GetString("resourceGroup")
if err != nil {
return fmt.Errorf("getting 'resourceGroup' flag: %w", err)
}
f.servicePrincipal, err = flags.GetString("servicePrincipal")
if err != nil {
return fmt.Errorf("getting 'servicePrincipal' flag: %w", err)
}
return nil
}
// azureIAMCreator implements the providerIAMCreator interface for Azure.
type azureIAMCreator struct {
flags azureIAMCreateFlags
}
func (c *azureIAMCreator) getIAMConfigOptions() *cloudcmd.IAMConfigOptions {
return &cloudcmd.IAMConfigOptions{
Azure: cloudcmd.AzureIAMConfig{
Region: c.flags.region,
ResourceGroup: c.flags.resourceGroup,
ServicePrincipal: c.flags.servicePrincipal,
},
}
}
func (c *azureIAMCreator) printConfirmValues(cmd *cobra.Command) {
cmd.Printf("Region:\t\t\t%s\n", c.flags.region)
cmd.Printf("Resource Group:\t\t%s\n", c.flags.resourceGroup)
cmd.Printf("Service Principal:\t%s\n\n", c.flags.servicePrincipal)
}
func (c *azureIAMCreator) printOutputValues(cmd *cobra.Command, iamFile cloudcmd.IAMOutput) {
cmd.Printf("subscription:\t\t%s\n", iamFile.AzureOutput.SubscriptionID)
cmd.Printf("tenant:\t\t\t%s\n", iamFile.AzureOutput.TenantID)
cmd.Printf("location:\t\t%s\n", c.flags.region)
cmd.Printf("resourceGroup:\t\t%s\n", c.flags.resourceGroup)
cmd.Printf("userAssignedIdentity:\t%s\n", iamFile.AzureOutput.UAMIID)
}
func (c *azureIAMCreator) writeOutputValuesToConfig(conf *config.Config, iamFile cloudcmd.IAMOutput) {
conf.Provider.Azure.SubscriptionID = iamFile.AzureOutput.SubscriptionID
conf.Provider.Azure.TenantID = iamFile.AzureOutput.TenantID
conf.Provider.Azure.Location = c.flags.region
conf.Provider.Azure.ResourceGroup = c.flags.resourceGroup
conf.Provider.Azure.UserAssignedIdentity = iamFile.AzureOutput.UAMIID
}
func (c *azureIAMCreator) parseAndWriteIDFile(_ cloudcmd.IAMOutput, _ file.Handler) error {
return nil
}
func (c *azureIAMCreator) validateConfigWithFlagCompatibility(conf config.Config) error {
return validateConfigWithFlagCompatibility(cloudprovider.Azure, conf, c.flags.region)
}

View File

@ -0,0 +1,153 @@
/*
Copyright (c) Edgeless Systems GmbH
SPDX-License-Identifier: AGPL-3.0-only
*/
package cmd
import (
"fmt"
"strings"
"github.com/edgelesssys/constellation/v2/cli/internal/cloudcmd"
"github.com/edgelesssys/constellation/v2/internal/cloud/cloudprovider"
"github.com/edgelesssys/constellation/v2/internal/config"
"github.com/edgelesssys/constellation/v2/internal/constants"
"github.com/edgelesssys/constellation/v2/internal/file"
"github.com/spf13/cobra"
"github.com/spf13/pflag"
)
// NewIAMCreateGCPCmd returns a new cobra.Command for the iam create gcp command.
func newIAMCreateGCPCmd() *cobra.Command {
cmd := &cobra.Command{
Use: "gcp",
Short: "Create IAM configuration on GCP for your Constellation cluster",
Long: "Create IAM configuration on GCP for your Constellation cluster.",
Args: cobra.ExactArgs(0),
RunE: runIAMCreateGCP,
}
cmd.Flags().String("zone", "", "GCP zone the cluster will be deployed in (required)\n"+
"Find a list of available zones here: https://cloud.google.com/compute/docs/regions-zones#available")
must(cobra.MarkFlagRequired(cmd.Flags(), "zone"))
cmd.Flags().String("serviceAccountID", "", "ID for the service account that will be created (required)\n"+
"Must be 6 to 30 lowercase letters, digits, or hyphens.")
must(cobra.MarkFlagRequired(cmd.Flags(), "serviceAccountID"))
cmd.Flags().String("projectID", "", "ID of the GCP project the configuration will be created in (required)\n"+
"Find it on the welcome screen of your project: https://console.cloud.google.com/welcome")
must(cobra.MarkFlagRequired(cmd.Flags(), "projectID"))
return cmd
}
func runIAMCreateGCP(cmd *cobra.Command, _ []string) error {
creator := &gcpIAMCreator{}
if err := creator.flags.parse(cmd.Flags()); err != nil {
return err
}
return runIAMCreate(cmd, creator, cloudprovider.GCP)
}
// gcpIAMCreateFlags contains the parsed flags of the iam create gcp command.
type gcpIAMCreateFlags struct {
rootFlags
serviceAccountID string
zone string
region string
projectID string
}
func (f *gcpIAMCreateFlags) parse(flags *pflag.FlagSet) error {
var err error
if err = f.rootFlags.parse(flags); err != nil {
return err
}
f.zone, err = flags.GetString("zone")
if err != nil {
return fmt.Errorf("getting 'zone' flag: %w", err)
}
if !zoneRegex.MatchString(f.zone) {
return fmt.Errorf("invalid zone string: %s", f.zone)
}
// Infer region from zone.
zoneParts := strings.Split(f.zone, "-")
f.region = fmt.Sprintf("%s-%s", zoneParts[0], zoneParts[1])
if !regionRegex.MatchString(f.region) {
return fmt.Errorf("invalid region string: %s", f.region)
}
f.projectID, err = flags.GetString("projectID")
if err != nil {
return fmt.Errorf("getting 'projectID' flag: %w", err)
}
if !gcpIDRegex.MatchString(f.projectID) {
return fmt.Errorf("projectID %q doesn't match %s", f.projectID, gcpIDRegex)
}
f.serviceAccountID, err = flags.GetString("serviceAccountID")
if err != nil {
return fmt.Errorf("getting 'serviceAccountID' flag: %w", err)
}
if !gcpIDRegex.MatchString(f.serviceAccountID) {
return fmt.Errorf("serviceAccountID %q doesn't match %s", f.serviceAccountID, gcpIDRegex)
}
return nil
}
// gcpIAMCreator implements the providerIAMCreator interface for GCP.
type gcpIAMCreator struct {
flags gcpIAMCreateFlags
}
func (c *gcpIAMCreator) getIAMConfigOptions() *cloudcmd.IAMConfigOptions {
return &cloudcmd.IAMConfigOptions{
GCP: cloudcmd.GCPIAMConfig{
Zone: c.flags.zone,
Region: c.flags.region,
ProjectID: c.flags.projectID,
ServiceAccountID: c.flags.serviceAccountID,
},
}
}
func (c *gcpIAMCreator) printConfirmValues(cmd *cobra.Command) {
cmd.Printf("Project ID:\t\t%s\n", c.flags.projectID)
cmd.Printf("Service Account ID:\t%s\n", c.flags.serviceAccountID)
cmd.Printf("Region:\t\t\t%s\n", c.flags.region)
cmd.Printf("Zone:\t\t\t%s\n\n", c.flags.zone)
}
func (c *gcpIAMCreator) printOutputValues(cmd *cobra.Command, _ cloudcmd.IAMOutput) {
cmd.Printf("projectID:\t\t%s\n", c.flags.projectID)
cmd.Printf("region:\t\t\t%s\n", c.flags.region)
cmd.Printf("zone:\t\t\t%s\n", c.flags.zone)
cmd.Printf("serviceAccountKeyPath:\t%s\n\n", c.flags.pathPrefixer.PrefixPrintablePath(constants.GCPServiceAccountKeyFilename))
}
func (c *gcpIAMCreator) writeOutputValuesToConfig(conf *config.Config, _ cloudcmd.IAMOutput) {
conf.Provider.GCP.Project = c.flags.projectID
conf.Provider.GCP.ServiceAccountKeyPath = constants.GCPServiceAccountKeyFilename // File was created in workspace, so only the filename is needed.
conf.Provider.GCP.Region = c.flags.region
conf.Provider.GCP.Zone = c.flags.zone
for groupName, group := range conf.NodeGroups {
group.Zone = c.flags.zone
conf.NodeGroups[groupName] = group
}
}
func (c *gcpIAMCreator) parseAndWriteIDFile(iamFile cloudcmd.IAMOutput, fileHandler file.Handler) error {
// GCP needs to write the service account key to a file.
tmpOut, err := parseIDFile(iamFile.GCPOutput.ServiceAccountKey)
if err != nil {
return err
}
return fileHandler.WriteJSON(constants.GCPServiceAccountKeyFilename, tmpOut, file.OptNone)
}
func (c *gcpIAMCreator) validateConfigWithFlagCompatibility(conf config.Config) error {
return validateConfigWithFlagCompatibility(cloudprovider.GCP, conf, c.flags.zone)
}

View File

@ -11,13 +11,12 @@ import (
"os"
"github.com/edgelesssys/constellation/v2/cli/internal/cloudcmd"
"github.com/edgelesssys/constellation/v2/cli/internal/cmd/pathprefix"
"github.com/edgelesssys/constellation/v2/cli/internal/terraform"
"github.com/edgelesssys/constellation/v2/internal/cloud/gcpshared"
"github.com/edgelesssys/constellation/v2/internal/constants"
"github.com/edgelesssys/constellation/v2/internal/file"
"github.com/spf13/afero"
"github.com/spf13/cobra"
"github.com/spf13/pflag"
)
// NewIAMDestroyCmd returns a new cobra.Command for the iam destroy subcommand.
@ -35,6 +34,25 @@ func newIAMDestroyCmd() *cobra.Command {
return cmd
}
type iamDestroyFlags struct {
rootFlags
yes bool
}
func (f *iamDestroyFlags) parse(flags *pflag.FlagSet) error {
if err := f.rootFlags.parse(flags); err != nil {
return err
}
yes, err := flags.GetBool("yes")
if err != nil {
return fmt.Errorf("getting 'yes' flag: %w", err)
}
f.yes = yes
return nil
}
func runIAMDestroy(cmd *cobra.Command, _ []string) error {
log, err := newCLILogger(cmd)
if err != nil {
@ -46,51 +64,47 @@ func runIAMDestroy(cmd *cobra.Command, _ []string) error {
fsHandler := file.NewHandler(afero.NewOsFs())
c := &destroyCmd{log: log}
if err := c.flags.parse(cmd.Flags()); err != nil {
return err
}
return c.iamDestroy(cmd, spinner, destroyer, fsHandler)
}
type destroyCmd struct {
log debugLog
pf pathprefix.PathPrefixer
log debugLog
flags iamDestroyFlags
}
func (c *destroyCmd) iamDestroy(cmd *cobra.Command, spinner spinnerInterf, destroyer iamDestroyer, fsHandler file.Handler) error {
flags, err := c.parseDestroyFlags(cmd)
if err != nil {
return fmt.Errorf("parsing flags: %w", err)
// check if there is a possibility that the cluster is still running by looking out for specific files
c.log.Debugf("Checking if %q exists", c.flags.pathPrefixer.PrefixPrintablePath(constants.AdminConfFilename))
if _, err := fsHandler.Stat(constants.AdminConfFilename); !errors.Is(err, os.ErrNotExist) {
return fmt.Errorf("file %q still exists, please make sure to terminate your cluster before destroying your IAM configuration", c.flags.pathPrefixer.PrefixPrintablePath(constants.AdminConfFilename))
}
// check if there is a possibility that the cluster is still running by looking out for specific files
c.log.Debugf("Checking if %q exists", c.pf.PrefixPrintablePath(constants.AdminConfFilename))
_, err = fsHandler.Stat(constants.AdminConfFilename)
if !errors.Is(err, os.ErrNotExist) {
return fmt.Errorf("file %q still exists, please make sure to terminate your cluster before destroying your IAM configuration", c.pf.PrefixPrintablePath(constants.AdminConfFilename))
}
c.log.Debugf("Checking if %q exists", c.pf.PrefixPrintablePath(constants.StateFilename))
_, err = fsHandler.Stat(constants.StateFilename)
if !errors.Is(err, os.ErrNotExist) {
return fmt.Errorf("file %q still exists, please make sure to terminate your cluster before destroying your IAM configuration", c.pf.PrefixPrintablePath(constants.StateFilename))
c.log.Debugf("Checking if %q exists", c.flags.pathPrefixer.PrefixPrintablePath(constants.StateFilename))
if _, err := fsHandler.Stat(constants.StateFilename); !errors.Is(err, os.ErrNotExist) {
return fmt.Errorf("file %q still exists, please make sure to terminate your cluster before destroying your IAM configuration", c.flags.pathPrefixer.PrefixPrintablePath(constants.StateFilename))
}
gcpFileExists := false
c.log.Debugf("Checking if %q exists", c.pf.PrefixPrintablePath(constants.GCPServiceAccountKeyFilename))
_, err = fsHandler.Stat(constants.GCPServiceAccountKeyFilename)
if err != nil {
c.log.Debugf("Checking if %q exists", c.flags.pathPrefixer.PrefixPrintablePath(constants.GCPServiceAccountKeyFilename))
if _, err := fsHandler.Stat(constants.GCPServiceAccountKeyFilename); err != nil {
if !errors.Is(err, os.ErrNotExist) {
return err
}
} else {
c.log.Debugf("%q exists", c.pf.PrefixPrintablePath(constants.GCPServiceAccountKeyFilename))
c.log.Debugf("%q exists", c.flags.pathPrefixer.PrefixPrintablePath(constants.GCPServiceAccountKeyFilename))
gcpFileExists = true
}
if !flags.yes {
if !c.flags.yes {
// Confirmation
confirmString := "Do you really want to destroy your IAM configuration? Note that this will remove all resources in the resource group."
if gcpFileExists {
confirmString += fmt.Sprintf("\nThis will also delete %q", c.pf.PrefixPrintablePath(constants.GCPServiceAccountKeyFilename))
confirmString += fmt.Sprintf("\nThis will also delete %q", c.flags.pathPrefixer.PrefixPrintablePath(constants.GCPServiceAccountKeyFilename))
}
ok, err := askToConfirm(cmd, confirmString)
if err != nil {
@ -103,7 +117,7 @@ func (c *destroyCmd) iamDestroy(cmd *cobra.Command, spinner spinnerInterf, destr
}
if gcpFileExists {
c.log.Debugf("Starting to delete %q", c.pf.PrefixPrintablePath(constants.GCPServiceAccountKeyFilename))
c.log.Debugf("Starting to delete %q", c.flags.pathPrefixer.PrefixPrintablePath(constants.GCPServiceAccountKeyFilename))
proceed, err := c.deleteGCPServiceAccountKeyFile(cmd, destroyer, fsHandler)
if err != nil {
return err
@ -118,7 +132,7 @@ func (c *destroyCmd) iamDestroy(cmd *cobra.Command, spinner spinnerInterf, destr
spinner.Start("Destroying IAM configuration", false)
defer spinner.Stop()
if err := destroyer.DestroyIAMConfiguration(cmd.Context(), constants.TerraformIAMWorkingDir, flags.tfLogLevel); err != nil {
if err := destroyer.DestroyIAMConfiguration(cmd.Context(), constants.TerraformIAMWorkingDir, c.flags.tfLogLevel); err != nil {
return fmt.Errorf("destroying IAM configuration: %w", err)
}
@ -130,7 +144,7 @@ func (c *destroyCmd) iamDestroy(cmd *cobra.Command, spinner spinnerInterf, destr
func (c *destroyCmd) deleteGCPServiceAccountKeyFile(cmd *cobra.Command, destroyer iamDestroyer, fsHandler file.Handler) (bool, error) {
var fileSaKey gcpshared.ServiceAccountKey
c.log.Debugf("Parsing %q", c.pf.PrefixPrintablePath(constants.GCPServiceAccountKeyFilename))
c.log.Debugf("Parsing %q", c.flags.pathPrefixer.PrefixPrintablePath(constants.GCPServiceAccountKeyFilename))
if err := fsHandler.ReadJSON(constants.GCPServiceAccountKeyFilename, &fileSaKey); err != nil {
return false, err
}
@ -143,7 +157,11 @@ func (c *destroyCmd) deleteGCPServiceAccountKeyFile(cmd *cobra.Command, destroye
c.log.Debugf("Checking if keys are the same")
if tfSaKey != fileSaKey {
cmd.Printf("The key in %q don't match up with your Terraform state. %q will not be deleted.\n", c.pf.PrefixPrintablePath(constants.GCPServiceAccountKeyFilename), c.pf.PrefixPrintablePath(constants.GCPServiceAccountKeyFilename))
cmd.Printf(
"The key in %q don't match up with your Terraform state. %q will not be deleted.\n",
c.flags.pathPrefixer.PrefixPrintablePath(constants.GCPServiceAccountKeyFilename),
c.flags.pathPrefixer.PrefixPrintablePath(constants.GCPServiceAccountKeyFilename),
)
return true, nil
}
@ -151,42 +169,6 @@ func (c *destroyCmd) deleteGCPServiceAccountKeyFile(cmd *cobra.Command, destroye
return false, err
}
c.log.Debugf("Successfully deleted %q", c.pf.PrefixPrintablePath(constants.GCPServiceAccountKeyFilename))
c.log.Debugf("Successfully deleted %q", c.flags.pathPrefixer.PrefixPrintablePath(constants.GCPServiceAccountKeyFilename))
return true, nil
}
type destroyFlags struct {
yes bool
tfLogLevel terraform.LogLevel
}
// parseDestroyFlags parses the flags of the create command.
func (c *destroyCmd) parseDestroyFlags(cmd *cobra.Command) (destroyFlags, error) {
yes, err := cmd.Flags().GetBool("yes")
if err != nil {
return destroyFlags{}, fmt.Errorf("parsing yes bool: %w", err)
}
c.log.Debugf("Yes flag is %t", yes)
workDir, err := cmd.Flags().GetString("workspace")
if err != nil {
return destroyFlags{}, fmt.Errorf("parsing workspace string: %w", err)
}
c.log.Debugf("Workspace set to %q", workDir)
c.pf = pathprefix.New(workDir)
logLevelString, err := cmd.Flags().GetString("tf-log")
if err != nil {
return destroyFlags{}, fmt.Errorf("parsing tf-log string: %w", err)
}
logLevel, err := terraform.ParseLogLevel(logLevelString)
if err != nil {
return destroyFlags{}, fmt.Errorf("parsing Terraform log level %s: %w", logLevelString, err)
}
c.log.Debugf("Terraform logs will be written into %s at level %s", c.pf.PrefixPrintablePath(constants.TerraformWorkingDir), logLevel.String())
return destroyFlags{
tfLogLevel: logLevel,
yes: yes,
}, nil
}

View File

@ -46,52 +46,52 @@ func TestIAMDestroy(t *testing.T) {
iamDestroyer *stubIAMDestroyer
fh file.Handler
stdin string
yesFlag string
yesFlag bool
wantErr bool
wantDestroyCalled bool
}{
"cluster running admin conf": {
fh: newFsWithAdminConf(),
iamDestroyer: &stubIAMDestroyer{},
yesFlag: "false",
yesFlag: false,
wantErr: true,
},
"cluster running cluster state": {
fh: newFsWithStateFile(),
iamDestroyer: &stubIAMDestroyer{},
yesFlag: "false",
yesFlag: false,
wantErr: true,
},
"file missing abort": {
fh: newFsMissing(),
stdin: "n\n",
yesFlag: "false",
yesFlag: false,
iamDestroyer: &stubIAMDestroyer{},
},
"file missing": {
fh: newFsMissing(),
stdin: "y\n",
yesFlag: "false",
yesFlag: false,
iamDestroyer: &stubIAMDestroyer{},
wantDestroyCalled: true,
},
"file exists abort": {
fh: newFsExists(),
stdin: "n\n",
yesFlag: "false",
yesFlag: false,
iamDestroyer: &stubIAMDestroyer{},
},
"error destroying user": {
fh: newFsMissing(),
stdin: "y\n",
yesFlag: "false",
yesFlag: false,
iamDestroyer: &stubIAMDestroyer{destroyErr: someError},
wantErr: true,
wantDestroyCalled: true,
},
"gcp delete error": {
fh: newFsExists(),
yesFlag: "true",
yesFlag: true,
iamDestroyer: &stubIAMDestroyer{getTfStateKeyErr: someError},
wantErr: true,
},
@ -106,13 +106,9 @@ func TestIAMDestroy(t *testing.T) {
cmd.SetErr(&bytes.Buffer{})
cmd.SetIn(bytes.NewBufferString(tc.stdin))
// register persistent flags manually
cmd.Flags().String("tf-log", "NONE", "")
cmd.Flags().String("workspace", "", "")
assert.NoError(cmd.Flags().Set("yes", tc.yesFlag))
c := &destroyCmd{log: logger.NewTest(t)}
c := &destroyCmd{log: logger.NewTest(t), flags: iamDestroyFlags{
yes: tc.yesFlag,
}}
err := c.iamDestroy(cmd, &nopSpinner{}, tc.iamDestroyer, tc.fh)

View File

@ -21,6 +21,7 @@ import (
"github.com/edgelesssys/constellation/v2/internal/file"
"github.com/spf13/afero"
"github.com/spf13/cobra"
"github.com/spf13/pflag"
)
func newIAMUpgradeCmd() *cobra.Command {
@ -46,17 +47,32 @@ func newIAMUpgradeApplyCmd() *cobra.Command {
return cmd
}
type iamUpgradeApplyFlags struct {
rootFlags
yes bool
}
func (f *iamUpgradeApplyFlags) parse(flags *pflag.FlagSet) error {
if err := f.rootFlags.parse(flags); err != nil {
return err
}
yes, err := flags.GetBool("yes")
if err != nil {
return fmt.Errorf("getting 'yes' flag: %w", err)
}
f.yes = yes
return nil
}
type iamUpgradeApplyCmd struct {
fileHandler file.Handler
log debugLog
configFetcher attestationconfigapi.Fetcher
flags iamUpgradeApplyFlags
}
func runIAMUpgradeApply(cmd *cobra.Command, _ []string) error {
force, err := cmd.Flags().GetBool("force")
if err != nil {
return fmt.Errorf("parsing force argument: %w", err)
}
fileHandler := file.NewHandler(afero.NewOsFs())
upgradeID := generateUpgradeID(upgradeCmdKindIAM)
upgradeDir := filepath.Join(constants.UpgradeDir, upgradeID)
@ -77,22 +93,20 @@ func runIAMUpgradeApply(cmd *cobra.Command, _ []string) error {
return fmt.Errorf("setting up logger: %w", err)
}
yes, err := cmd.Flags().GetBool("yes")
if err != nil {
return err
}
i := iamUpgradeApplyCmd{
fileHandler: fileHandler,
log: log,
configFetcher: configFetcher,
}
if err := i.flags.parse(cmd.Flags()); err != nil {
return err
}
return i.iamUpgradeApply(cmd, iamMigrateCmd, upgradeDir, force, yes)
return i.iamUpgradeApply(cmd, iamMigrateCmd, upgradeDir)
}
func (i iamUpgradeApplyCmd) iamUpgradeApply(cmd *cobra.Command, iamUpgrader iamUpgrader, upgradeDir string, force, yes bool) error {
conf, err := config.New(i.fileHandler, constants.ConfigFilename, i.configFetcher, force)
func (i iamUpgradeApplyCmd) iamUpgradeApply(cmd *cobra.Command, iamUpgrader iamUpgrader, upgradeDir string) error {
conf, err := config.New(i.fileHandler, constants.ConfigFilename, i.configFetcher, i.flags.force)
var configValidationErr *config.ValidationError
if errors.As(err, &configValidationErr) {
cmd.PrintErrln(configValidationErr.LongMessage())
@ -109,14 +123,14 @@ func (i iamUpgradeApplyCmd) iamUpgradeApply(cmd *cobra.Command, iamUpgrader iamU
if err != nil {
return fmt.Errorf("planning terraform migrations: %w", err)
}
if !hasDiff && !force {
if !hasDiff && !i.flags.force {
cmd.Println("No IAM migrations necessary.")
return nil
}
// If there are any Terraform migrations to apply, ask for confirmation
cmd.Println("The IAM upgrade requires a migration by applying an updated Terraform template. Please manually review the suggested changes.")
if !yes {
if !i.flags.yes {
ok, err := askToConfirm(cmd, "Do you want to apply the IAM upgrade?")
if err != nil {
return fmt.Errorf("asking for confirmation: %w", err)

View File

@ -132,9 +132,12 @@ func TestIamUpgradeApply(t *testing.T) {
fileHandler: tc.fh,
log: logger.NewTest(t),
configFetcher: tc.configFetcher,
flags: iamUpgradeApplyFlags{
yes: tc.yesFlag,
},
}
err := iamUpgradeApplyCmd.iamUpgradeApply(cmd, tc.iamUpgrader, "", false, tc.yesFlag)
err := iamUpgradeApplyCmd.iamUpgradeApply(cmd, tc.iamUpgrader, "")
if tc.wantErr {
assert.Error(err)
} else {

View File

@ -28,6 +28,7 @@ import (
"github.com/spf13/afero"
"github.com/spf13/cobra"
"github.com/spf13/pflag"
"google.golang.org/grpc"
"k8s.io/apimachinery/pkg/runtime"
"k8s.io/client-go/tools/clientcmd"
@ -36,7 +37,6 @@ import (
"github.com/edgelesssys/constellation/v2/bootstrapper/initproto"
"github.com/edgelesssys/constellation/v2/cli/internal/cloudcmd"
"github.com/edgelesssys/constellation/v2/cli/internal/cmd/pathprefix"
"github.com/edgelesssys/constellation/v2/cli/internal/helm"
"github.com/edgelesssys/constellation/v2/cli/internal/kubecmd"
"github.com/edgelesssys/constellation/v2/cli/internal/state"
@ -69,12 +69,45 @@ func NewInitCmd() *cobra.Command {
return cmd
}
// initFlags are flags used by the init command.
type initFlags struct {
rootFlags
conformance bool
helmWaitMode helm.WaitMode
mergeConfigs bool
}
func (f *initFlags) parse(flags *pflag.FlagSet) error {
if err := f.rootFlags.parse(flags); err != nil {
return err
}
skipHelmWait, err := flags.GetBool("skip-helm-wait")
if err != nil {
return fmt.Errorf("getting 'skip-helm-wait' flag: %w", err)
}
f.helmWaitMode = helm.WaitModeAtomic
if skipHelmWait {
f.helmWaitMode = helm.WaitModeNone
}
f.conformance, err = flags.GetBool("conformance")
if err != nil {
return fmt.Errorf("getting 'conformance' flag: %w", err)
}
f.mergeConfigs, err = flags.GetBool("merge-kubeconfig")
if err != nil {
return fmt.Errorf("getting 'merge-kubeconfig' flag: %w", err)
}
return nil
}
type initCmd struct {
log debugLog
merger configMerger
spinner spinnerInterf
fileHandler file.Handler
pf pathprefix.PathPrefixer
flags initFlags
}
func newInitCmd(fileHandler file.Handler, spinner spinnerInterf, merger configMerger, log debugLog) *initCmd {
@ -109,6 +142,11 @@ func runInitialize(cmd *cobra.Command, _ []string) error {
cmd.SetContext(ctx)
i := newInitCmd(fileHandler, spinner, &kubeconfigMerger{log: log}, log)
if err := i.flags.parse(cmd.Flags()); err != nil {
return err
}
i.log.Debugf("Using flags: %+v", i.flags)
fetcher := attestationconfigapi.NewFetcher()
newAttestationApplier := func(w io.Writer, kubeConfig string, log debugLog) (attestationConfigApplier, error) {
return kubecmd.New(w, kubeConfig, fileHandler, log)
@ -127,13 +165,8 @@ func (i *initCmd) initialize(
newAttestationApplier func(io.Writer, string, debugLog) (attestationConfigApplier, error),
newHelmClient func(kubeConfigPath string, log debugLog) (helmApplier, error),
) error {
flags, err := i.evalFlagArgs(cmd)
if err != nil {
return err
}
i.log.Debugf("Using flags: %+v", flags)
i.log.Debugf("Loading configuration file from %q", i.pf.PrefixPrintablePath(constants.ConfigFilename))
conf, err := config.New(i.fileHandler, constants.ConfigFilename, configFetcher, flags.force)
i.log.Debugf("Loading configuration file from %q", i.flags.pathPrefixer.PrefixPrintablePath(constants.ConfigFilename))
conf, err := config.New(i.fileHandler, constants.ConfigFilename, configFetcher, i.flags.force)
var configValidationErr *config.ValidationError
if errors.As(err, &configValidationErr) {
cmd.PrintErrln(configValidationErr.LongMessage())
@ -146,7 +179,7 @@ func (i *initCmd) initialize(
if err != nil {
return err
}
if !flags.force {
if !i.flags.force {
if err := validateCLIandConstellationVersionAreEqual(constants.BinaryVersion(), conf.Image, conf.MicroserviceVersion); err != nil {
return err
}
@ -183,7 +216,7 @@ func (i *initCmd) initialize(
return fmt.Errorf("creating new validator: %w", err)
}
i.log.Debugf("Created a new validator")
serviceAccURI, err := cloudcmd.GetMarshaledServiceAccountURI(provider, conf, i.pf, i.log, i.fileHandler)
serviceAccURI, err := cloudcmd.GetMarshaledServiceAccountURI(provider, conf, i.flags.pathPrefixer, i.log, i.fileHandler)
if err != nil {
return err
}
@ -211,7 +244,7 @@ func (i *initCmd) initialize(
MeasurementSalt: measurementSalt,
KubernetesVersion: versions.VersionConfigs[k8sVersion].ClusterVersion,
KubernetesComponents: versions.VersionConfigs[k8sVersion].KubernetesComponents.ToInitProto(),
ConformanceMode: flags.conformance,
ConformanceMode: i.flags.conformance,
InitSecret: stateFile.Infrastructure.InitSecret,
ClusterName: stateFile.Infrastructure.Name,
ApiserverCertSans: stateFile.Infrastructure.APIServerCertSANs,
@ -228,7 +261,7 @@ func (i *initCmd) initialize(
if nonRetriable.logCollectionErr != nil {
cmd.PrintErrf("Failed to collect logs from bootstrapper: %s\n", nonRetriable.logCollectionErr)
} else {
cmd.PrintErrf("Fetched bootstrapper logs are stored in %q\n", i.pf.PrefixPrintablePath(constants.ErrorLog))
cmd.PrintErrf("Fetched bootstrapper logs are stored in %q\n", i.flags.pathPrefixer.PrefixPrintablePath(constants.ErrorLog))
}
}
return err
@ -236,7 +269,7 @@ func (i *initCmd) initialize(
i.log.Debugf("Initialization request succeeded")
bufferedOutput := &bytes.Buffer{}
if err := i.writeOutput(stateFile, resp, flags.mergeConfigs, bufferedOutput, measurementSalt); err != nil {
if err := i.writeOutput(stateFile, resp, i.flags.mergeConfigs, bufferedOutput, measurementSalt); err != nil {
return err
}
@ -250,9 +283,9 @@ func (i *initCmd) initialize(
i.spinner.Start("Installing Kubernetes components ", false)
options := helm.Options{
Force: flags.force,
Conformance: flags.conformance,
HelmWaitMode: flags.helmWaitMode,
Force: i.flags.force,
Conformance: i.flags.conformance,
HelmWaitMode: i.flags.helmWaitMode,
AllowDestructive: helm.DenyDestructive,
}
helmApplier, err := newHelmClient(constants.AdminConfFilename, i.log)
@ -457,7 +490,7 @@ func (i *initCmd) writeOutput(
tw := tabwriter.NewWriter(wr, 0, 0, 2, ' ', 0)
writeRow(tw, "Constellation cluster identifier", clusterID)
writeRow(tw, "Kubernetes configuration", i.pf.PrefixPrintablePath(constants.AdminConfFilename))
writeRow(tw, "Kubernetes configuration", i.flags.pathPrefixer.PrefixPrintablePath(constants.AdminConfFilename))
tw.Flush()
fmt.Fprintln(wr)
@ -485,7 +518,7 @@ func (i *initCmd) writeOutput(
if err := i.fileHandler.Write(constants.AdminConfFilename, kubeconfigBytes, file.OptNone); err != nil {
return fmt.Errorf("writing kubeconfig: %w", err)
}
i.log.Debugf("Kubeconfig written to %s", i.pf.PrefixPrintablePath(constants.AdminConfFilename))
i.log.Debugf("Kubeconfig written to %s", i.flags.pathPrefixer.PrefixPrintablePath(constants.AdminConfFilename))
if mergeConfig {
if err := i.merger.mergeConfigs(constants.AdminConfFilename, i.fileHandler); err != nil {
@ -500,7 +533,7 @@ func (i *initCmd) writeOutput(
return fmt.Errorf("writing Constellation state file: %w", err)
}
i.log.Debugf("Constellation state file written to %s", i.pf.PrefixPrintablePath(constants.StateFilename))
i.log.Debugf("Constellation state file written to %s", i.flags.pathPrefixer.PrefixPrintablePath(constants.StateFilename))
if !mergeConfig {
fmt.Fprintln(wr, "You can now connect to your cluster by executing:")
@ -528,57 +561,6 @@ func writeRow(wr io.Writer, col1 string, col2 string) {
fmt.Fprint(wr, col1, "\t", col2, "\n")
}
// evalFlagArgs gets the flag values and does preprocessing of these values like
// reading the content from file path flags and deriving other values from flag combinations.
func (i *initCmd) evalFlagArgs(cmd *cobra.Command) (initFlags, error) {
conformance, err := cmd.Flags().GetBool("conformance")
if err != nil {
return initFlags{}, fmt.Errorf("parsing conformance flag: %w", err)
}
i.log.Debugf("Conformance flag is %t", conformance)
skipHelmWait, err := cmd.Flags().GetBool("skip-helm-wait")
if err != nil {
return initFlags{}, fmt.Errorf("parsing skip-helm-wait flag: %w", err)
}
helmWaitMode := helm.WaitModeAtomic
if skipHelmWait {
helmWaitMode = helm.WaitModeNone
}
i.log.Debugf("Helm wait flag is %t", skipHelmWait)
workDir, err := cmd.Flags().GetString("workspace")
if err != nil {
return initFlags{}, fmt.Errorf("parsing config path flag: %w", err)
}
i.pf = pathprefix.New(workDir)
mergeConfigs, err := cmd.Flags().GetBool("merge-kubeconfig")
if err != nil {
return initFlags{}, fmt.Errorf("parsing merge-kubeconfig flag: %w", err)
}
i.log.Debugf("Merge kubeconfig flag is %t", mergeConfigs)
force, err := cmd.Flags().GetBool("force")
if err != nil {
return initFlags{}, fmt.Errorf("parsing force argument: %w", err)
}
i.log.Debugf("force flag is %t", force)
return initFlags{
conformance: conformance,
helmWaitMode: helmWaitMode,
force: force,
mergeConfigs: mergeConfigs,
}, nil
}
// initFlags are the resulting values of flag preprocessing.
type initFlags struct {
conformance bool
helmWaitMode helm.WaitMode
force bool
mergeConfigs bool
}
// generateMasterSecret reads a base64 encoded master secret from file or generates a new 32 byte secret.
func (i *initCmd) generateMasterSecret(outWriter io.Writer) (uri.MasterSecret, error) {
// No file given, generate a new secret, and save it to disk
@ -599,7 +581,7 @@ func (i *initCmd) generateMasterSecret(outWriter io.Writer) (uri.MasterSecret, e
if err := i.fileHandler.WriteJSON(constants.MasterSecretFilename, secret, file.OptNone); err != nil {
return uri.MasterSecret{}, err
}
fmt.Fprintf(outWriter, "Your Constellation master secret was successfully written to %q\n", i.pf.PrefixPrintablePath(constants.MasterSecretFilename))
fmt.Fprintf(outWriter, "Your Constellation master secret was successfully written to %q\n", i.flags.pathPrefixer.PrefixPrintablePath(constants.MasterSecretFilename))
return secret, nil
}

View File

@ -223,10 +223,6 @@ func TestInitialize(t *testing.T) {
var errOut bytes.Buffer
cmd.SetErr(&errOut)
// Flags
cmd.Flags().String("workspace", "", "") // register persistent flag manually
cmd.Flags().Bool("force", true, "") // register persistent flag manually
// File system preparation
fs := afero.NewMemMapFs()
fileHandler := file.NewHandler(fs)
@ -249,6 +245,8 @@ func TestInitialize(t *testing.T) {
defer cancel()
cmd.SetContext(ctx)
i := newInitCmd(fileHandler, &nopSpinner{}, nil, logger.NewTest(t))
i.flags.force = true
err := i.initialize(
cmd,
newDialer,
@ -442,15 +440,15 @@ func TestWriteOutput(t *testing.T) {
require.NoError(afs.Remove(constants.AdminConfFilename))
// test custom workspace
i.pf = pathprefix.New("/some/path")
i.flags.pathPrefixer = pathprefix.New("/some/path")
err = i.writeOutput(stateFile, resp.GetInitSuccess(), true, &out, measurementSalt)
require.NoError(err)
assert.Contains(out.String(), clusterID)
assert.Contains(out.String(), i.pf.PrefixPrintablePath(constants.AdminConfFilename))
assert.Contains(out.String(), i.flags.pathPrefixer.PrefixPrintablePath(constants.AdminConfFilename))
out.Reset()
// File is written to current working dir, we simply pass the workspace for generating readable user output
require.NoError(afs.Remove(constants.AdminConfFilename))
i.pf = pathprefix.PathPrefixer{}
i.flags.pathPrefixer = pathprefix.PathPrefixer{}
// test config merging
err = i.writeOutput(stateFile, resp.GetInitSuccess(), true, &out, measurementSalt)

View File

@ -14,13 +14,11 @@ import (
"net"
"github.com/edgelesssys/constellation/v2/cli/internal/cloudcmd"
"github.com/edgelesssys/constellation/v2/cli/internal/cmd/pathprefix"
"github.com/edgelesssys/constellation/v2/cli/internal/featureset"
"github.com/edgelesssys/constellation/v2/cli/internal/helm"
"github.com/edgelesssys/constellation/v2/cli/internal/kubecmd"
"github.com/edgelesssys/constellation/v2/cli/internal/libvirt"
"github.com/edgelesssys/constellation/v2/cli/internal/state"
"github.com/edgelesssys/constellation/v2/cli/internal/terraform"
"github.com/edgelesssys/constellation/v2/internal/api/attestationconfigapi"
"github.com/edgelesssys/constellation/v2/internal/atls"
"github.com/edgelesssys/constellation/v2/internal/cloud/cloudprovider"
@ -51,6 +49,8 @@ func newMiniUpCmd() *cobra.Command {
type miniUpCmd struct {
log debugLog
configFetcher attestationconfigapi.Fetcher
fileHandler file.Handler
flags rootFlags
}
func runUp(cmd *cobra.Command, _ []string) error {
@ -66,7 +66,14 @@ func runUp(cmd *cobra.Command, _ []string) error {
defer spinner.Stop()
creator := cloudcmd.NewCreator(spinner)
m := &miniUpCmd{log: log, configFetcher: attestationconfigapi.NewFetcher()}
m := &miniUpCmd{
log: log,
configFetcher: attestationconfigapi.NewFetcher(),
fileHandler: file.NewHandler(afero.NewOsFs()),
}
if err := m.flags.parse(cmd.Flags()); err != nil {
return err
}
return m.up(cmd, creator, spinner)
}
@ -75,22 +82,15 @@ func (m *miniUpCmd) up(cmd *cobra.Command, creator cloudCreator, spinner spinner
return fmt.Errorf("system requirements not met: %w", err)
}
flags, err := m.parseUpFlags(cmd)
if err != nil {
return fmt.Errorf("parsing flags: %w", err)
}
fileHandler := file.NewHandler(afero.NewOsFs())
// create config if not passed as flag and set default values
config, err := m.prepareConfig(cmd, fileHandler, flags)
config, err := m.prepareConfig(cmd)
if err != nil {
return fmt.Errorf("preparing config: %w", err)
}
// create cluster
spinner.Start("Creating cluster in QEMU ", false)
err = m.createMiniCluster(cmd.Context(), fileHandler, creator, config, flags)
err = m.createMiniCluster(cmd.Context(), creator, config)
spinner.Stop()
if err != nil {
return fmt.Errorf("creating cluster: %w", err)
@ -105,7 +105,7 @@ func (m *miniUpCmd) up(cmd *cobra.Command, creator cloudCreator, spinner spinner
cmd.Printf("\tvirsh -c %s\n\n", connectURI)
// initialize cluster
if err := m.initializeMiniCluster(cmd, fileHandler, spinner); err != nil {
if err := m.initializeMiniCluster(cmd, spinner); err != nil {
return fmt.Errorf("initializing cluster: %w", err)
}
m.log.Debugf("Initialized cluster")
@ -113,8 +113,8 @@ func (m *miniUpCmd) up(cmd *cobra.Command, creator cloudCreator, spinner spinner
}
// prepareConfig reads a given config, or creates a new minimal QEMU config.
func (m *miniUpCmd) prepareConfig(cmd *cobra.Command, fileHandler file.Handler, flags upFlags) (*config.Config, error) {
_, err := fileHandler.Stat(constants.ConfigFilename)
func (m *miniUpCmd) prepareConfig(cmd *cobra.Command) (*config.Config, error) {
_, err := m.fileHandler.Stat(constants.ConfigFilename)
if err == nil {
// config already exists, prompt user if they want to use this file
cmd.PrintErrln("A config file already exists in the configured workspace.")
@ -123,7 +123,7 @@ func (m *miniUpCmd) prepareConfig(cmd *cobra.Command, fileHandler file.Handler,
return nil, err
}
if ok {
return m.prepareExistingConfig(cmd, fileHandler, flags)
return m.prepareExistingConfig(cmd)
}
// user declined to reuse config file, prompt if they want to overwrite it
@ -146,11 +146,11 @@ func (m *miniUpCmd) prepareConfig(cmd *cobra.Command, fileHandler file.Handler,
}
m.log.Debugf("Prepared configuration")
return config, fileHandler.WriteYAML(constants.ConfigFilename, config, file.OptOverwrite)
return config, m.fileHandler.WriteYAML(constants.ConfigFilename, config, file.OptOverwrite)
}
func (m *miniUpCmd) prepareExistingConfig(cmd *cobra.Command, fileHandler file.Handler, flags upFlags) (*config.Config, error) {
conf, err := config.New(fileHandler, constants.ConfigFilename, m.configFetcher, flags.force)
func (m *miniUpCmd) prepareExistingConfig(cmd *cobra.Command) (*config.Config, error) {
conf, err := config.New(m.fileHandler, constants.ConfigFilename, m.configFetcher, m.flags.force)
var configValidationErr *config.ValidationError
if errors.As(err, &configValidationErr) {
cmd.PrintErrln(configValidationErr.LongMessage())
@ -165,13 +165,13 @@ func (m *miniUpCmd) prepareExistingConfig(cmd *cobra.Command, fileHandler file.H
}
// createMiniCluster creates a new cluster using the given config.
func (m *miniUpCmd) createMiniCluster(ctx context.Context, fileHandler file.Handler, creator cloudCreator, config *config.Config, flags upFlags) error {
func (m *miniUpCmd) createMiniCluster(ctx context.Context, creator cloudCreator, config *config.Config) error {
m.log.Debugf("Creating mini cluster")
opts := cloudcmd.CreateOptions{
Provider: cloudprovider.QEMU,
Config: config,
TFWorkspace: constants.TerraformWorkingDir,
TFLogLevel: flags.tfLogLevel,
TFLogLevel: m.flags.tfLogLevel,
}
infraState, err := creator.Create(ctx, opts)
if err != nil {
@ -184,11 +184,11 @@ func (m *miniUpCmd) createMiniCluster(ctx context.Context, fileHandler file.Hand
SetInfrastructure(infraState)
m.log.Debugf("Cluster state file contains %v", stateFile)
return stateFile.WriteToFile(fileHandler, constants.StateFilename)
return stateFile.WriteToFile(m.fileHandler, constants.StateFilename)
}
// initializeMiniCluster initializes a QEMU cluster.
func (m *miniUpCmd) initializeMiniCluster(cmd *cobra.Command, fileHandler file.Handler, spinner spinnerInterf) (retErr error) {
func (m *miniUpCmd) initializeMiniCluster(cmd *cobra.Command, spinner spinnerInterf) (retErr error) {
m.log.Debugf("Initializing mini cluster")
// clean up cluster resources if initialization fails
defer func() {
@ -214,12 +214,17 @@ func (m *miniUpCmd) initializeMiniCluster(cmd *cobra.Command, fileHandler file.H
defer log.Sync()
newAttestationApplier := func(w io.Writer, kubeConfig string, log debugLog) (attestationConfigApplier, error) {
return kubecmd.New(w, kubeConfig, fileHandler, log)
return kubecmd.New(w, kubeConfig, m.fileHandler, log)
}
newHelmClient := func(kubeConfigPath string, log debugLog) (helmApplier, error) {
return helm.NewClient(kubeConfigPath, log)
} // need to defer helm client instantiation until kubeconfig is available
i := newInitCmd(fileHandler, spinner, &kubeconfigMerger{log: log}, log)
i := newInitCmd(m.fileHandler, spinner, &kubeconfigMerger{log: log}, log)
if err := i.flags.parse(cmd.Flags()); err != nil {
return err
}
if err := i.initialize(cmd, newDialer, license.NewClient(), m.configFetcher,
newAttestationApplier, newHelmClient); err != nil {
return err
@ -227,37 +232,3 @@ func (m *miniUpCmd) initializeMiniCluster(cmd *cobra.Command, fileHandler file.H
m.log.Debugf("Initialized mini cluster")
return nil
}
type upFlags struct {
force bool
tfLogLevel terraform.LogLevel
}
func (m *miniUpCmd) parseUpFlags(cmd *cobra.Command) (upFlags, error) {
m.log.Debugf("Preparing configuration")
workDir, err := cmd.Flags().GetString("workspace")
if err != nil {
return upFlags{}, fmt.Errorf("parsing config string: %w", err)
}
m.log.Debugf("Workspace set to %q", workDir)
force, err := cmd.Flags().GetBool("force")
if err != nil {
return upFlags{}, fmt.Errorf("parsing force bool: %w", err)
}
m.log.Debugf("force flag is %q", force)
logLevelString, err := cmd.Flags().GetString("tf-log")
if err != nil {
return upFlags{}, fmt.Errorf("parsing tf-log string: %w", err)
}
logLevel, err := terraform.ParseLogLevel(logLevelString)
if err != nil {
return upFlags{}, fmt.Errorf("parsing Terraform log level %s: %w", logLevelString, err)
}
m.log.Debugf("Terraform logs will be written into %s at level %s", pathprefix.New(workDir).PrefixPrintablePath(constants.TerraformLogFile), logLevel.String())
return upFlags{
force: force,
tfLogLevel: logLevel,
}, nil
}

View File

@ -16,7 +16,6 @@ import (
"time"
"github.com/edgelesssys/constellation/v2/cli/internal/cloudcmd"
"github.com/edgelesssys/constellation/v2/cli/internal/cmd/pathprefix"
"github.com/edgelesssys/constellation/v2/cli/internal/state"
"github.com/edgelesssys/constellation/v2/disk-mapper/recoverproto"
"github.com/edgelesssys/constellation/v2/internal/api/attestationconfigapi"
@ -31,6 +30,7 @@ import (
"github.com/edgelesssys/constellation/v2/internal/retry"
"github.com/spf13/afero"
"github.com/spf13/cobra"
"github.com/spf13/pflag"
)
// NewRecoverCmd returns a new cobra.Command for the recover command.
@ -47,10 +47,28 @@ func NewRecoverCmd() *cobra.Command {
return cmd
}
type recoverFlags struct {
rootFlags
endpoint string
}
func (f *recoverFlags) parse(flags *pflag.FlagSet) error {
if err := f.rootFlags.parse(flags); err != nil {
return err
}
endpoint, err := flags.GetString("endpoint")
if err != nil {
return fmt.Errorf("getting 'endpoint' flag: %w", err)
}
f.endpoint = endpoint
return nil
}
type recoverCmd struct {
log debugLog
configFetcher attestationconfigapi.Fetcher
pf pathprefix.PathPrefixer
flags recoverFlags
}
func runRecover(cmd *cobra.Command, _ []string) error {
@ -64,6 +82,10 @@ func runRecover(cmd *cobra.Command, _ []string) error {
return dialer.New(nil, validator, &net.Dialer{})
}
r := &recoverCmd{log: log, configFetcher: attestationconfigapi.NewFetcher()}
if err := r.flags.parse(cmd.Flags()); err != nil {
return err
}
r.log.Debugf("Using flags: %+v", r.flags)
return r.recover(cmd, fileHandler, 5*time.Second, &recoverDoer{log: r.log}, newDialer)
}
@ -71,20 +93,14 @@ func (r *recoverCmd) recover(
cmd *cobra.Command, fileHandler file.Handler, interval time.Duration,
doer recoverDoerInterface, newDialer func(validator atls.Validator) *dialer.Dialer,
) error {
flags, err := r.parseRecoverFlags(cmd, fileHandler)
if err != nil {
return err
}
r.log.Debugf("Using flags: %+v", flags)
var masterSecret uri.MasterSecret
r.log.Debugf("Loading master secret file from %s", r.pf.PrefixPrintablePath(constants.MasterSecretFilename))
r.log.Debugf("Loading master secret file from %s", r.flags.pathPrefixer.PrefixPrintablePath(constants.MasterSecretFilename))
if err := fileHandler.ReadJSON(constants.MasterSecretFilename, &masterSecret); err != nil {
return err
}
r.log.Debugf("Loading configuration file from %q", r.pf.PrefixPrintablePath(constants.ConfigFilename))
conf, err := config.New(fileHandler, constants.ConfigFilename, r.configFetcher, flags.force)
r.log.Debugf("Loading configuration file from %q", r.flags.pathPrefixer.PrefixPrintablePath(constants.ConfigFilename))
conf, err := config.New(fileHandler, constants.ConfigFilename, r.configFetcher, r.flags.force)
var configValidationErr *config.ValidationError
if errors.As(err, &configValidationErr) {
cmd.PrintErrln(configValidationErr.LongMessage())
@ -99,15 +115,26 @@ func (r *recoverCmd) recover(
interval = 20 * time.Second // Azure LB takes a while to remove unhealthy instances
}
conf.UpdateMAAURL(flags.maaURL)
stateFile, err := state.ReadFromFile(fileHandler, constants.StateFilename)
if err != nil {
return fmt.Errorf("reading state file: %w", err)
}
endpoint, err := r.parseEndpoint(stateFile)
if err != nil {
return err
}
if stateFile.Infrastructure.Azure != nil {
conf.UpdateMAAURL(stateFile.Infrastructure.Azure.AttestationURL)
}
r.log.Debugf("Creating aTLS Validator for %s", conf.GetAttestationConfig().GetVariant())
validator, err := cloudcmd.NewValidator(cmd, conf.GetAttestationConfig(), r.log)
if err != nil {
return fmt.Errorf("creating new validator: %w", err)
}
r.log.Debugf("Created a new validator")
doer.setDialer(newDialer(validator), flags.endpoint)
r.log.Debugf("Set dialer for endpoint %s", flags.endpoint)
doer.setDialer(newDialer(validator), endpoint)
r.log.Debugf("Set dialer for endpoint %s", endpoint)
doer.setURIs(masterSecret.EncodeToURI(), uri.NoStoreURI)
r.log.Debugf("Set secrets")
if err := r.recoverCall(cmd.Context(), cmd.OutOrStdout(), interval, doer); err != nil {
@ -160,6 +187,18 @@ func (r *recoverCmd) recoverCall(ctx context.Context, out io.Writer, interval ti
return err
}
func (r *recoverCmd) parseEndpoint(state *state.State) (string, error) {
endpoint := r.flags.endpoint
if endpoint == "" {
endpoint = state.Infrastructure.ClusterEndpoint
}
endpoint, err := addPortIfMissing(endpoint, constants.RecoveryPort)
if err != nil {
return "", fmt.Errorf("validating cluster endpoint: %w", err)
}
return endpoint, nil
}
type recoverDoerInterface interface {
Do(ctx context.Context) error
setDialer(dialer grpcDialer, endpoint string)
@ -209,55 +248,3 @@ func (d *recoverDoer) setURIs(kmsURI, storageURI string) {
d.kmsURI = kmsURI
d.storageURI = storageURI
}
type recoverFlags struct {
endpoint string
maaURL string
force bool
}
func (r *recoverCmd) parseRecoverFlags(cmd *cobra.Command, fileHandler file.Handler) (recoverFlags, error) {
workDir, err := cmd.Flags().GetString("workspace")
if err != nil {
return recoverFlags{}, fmt.Errorf("parsing config path argument: %w", err)
}
r.log.Debugf("Workspace set to %q", workDir)
r.pf = pathprefix.New(workDir)
endpoint, err := cmd.Flags().GetString("endpoint")
r.log.Debugf("Endpoint flag is %s", endpoint)
if err != nil {
return recoverFlags{}, fmt.Errorf("parsing endpoint argument: %w", err)
}
force, err := cmd.Flags().GetBool("force")
if err != nil {
return recoverFlags{}, fmt.Errorf("parsing force argument: %w", err)
}
var attestationURL string
stateFile := state.New()
if endpoint == "" {
stateFile, err = state.ReadFromFile(fileHandler, constants.StateFilename)
if err != nil {
return recoverFlags{}, fmt.Errorf("reading state file: %w", err)
}
endpoint = stateFile.Infrastructure.ClusterEndpoint
}
endpoint, err = addPortIfMissing(endpoint, constants.RecoveryPort)
if err != nil {
return recoverFlags{}, fmt.Errorf("validating endpoint argument: %w", err)
}
r.log.Debugf("Endpoint value after parsing is %s", endpoint)
if stateFile.Infrastructure.Azure != nil {
attestationURL = stateFile.Infrastructure.Azure.AttestationURL
}
return recoverFlags{
endpoint: endpoint,
maaURL: attestationURL,
force: force,
}, nil
}

View File

@ -140,12 +140,9 @@ func TestRecover(t *testing.T) {
cmd := NewRecoverCmd()
cmd.SetContext(context.Background())
cmd.Flags().String("workspace", "", "") // register persistent flag manually
cmd.Flags().Bool("force", true, "") // register persistent flag manually
out := &bytes.Buffer{}
cmd.SetOut(out)
cmd.SetErr(out)
require.NoError(cmd.Flags().Set("endpoint", tc.endpoint))
fs := afero.NewMemMapFs()
fileHandler := file.NewHandler(fs)
@ -156,13 +153,25 @@ func TestRecover(t *testing.T) {
}
require.NoError(fileHandler.WriteJSON(
"constellation-mastersecret.json",
constants.MasterSecretFilename,
uri.MasterSecret{Key: tc.masterSecret.Secret, Salt: tc.masterSecret.Salt},
file.OptNone,
))
require.NoError(fileHandler.WriteYAML(
constants.StateFilename,
state.New(),
file.OptNone,
))
newDialer := func(atls.Validator) *dialer.Dialer { return nil }
r := &recoverCmd{log: logger.NewTest(t), configFetcher: stubAttestationFetcher{}}
r := &recoverCmd{
log: logger.NewTest(t),
configFetcher: stubAttestationFetcher{},
flags: recoverFlags{
rootFlags: rootFlags{force: true},
endpoint: tc.endpoint,
},
}
err := r.recover(cmd, fileHandler, time.Millisecond, tc.doer, newDialer)
if tc.wantErr {
assert.Error(err)
@ -183,68 +192,6 @@ func TestRecover(t *testing.T) {
}
}
func TestParseRecoverFlags(t *testing.T) {
testCases := map[string]struct {
args []string
wantFlags recoverFlags
writeStateFile bool
wantErr bool
}{
"no flags": {
wantFlags: recoverFlags{
endpoint: "192.0.2.42:9999",
},
writeStateFile: true,
},
"no flags, no ID file": {
wantFlags: recoverFlags{
endpoint: "192.0.2.42:9999",
},
wantErr: true,
},
"invalid endpoint": {
args: []string{"-e", "192.0.2.42:2:2"},
wantErr: true,
},
"all args set": {
args: []string{"-e", "192.0.2.42:2", "--workspace", "./constellation-workspace"},
wantFlags: recoverFlags{
endpoint: "192.0.2.42:2",
},
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
cmd := NewRecoverCmd()
cmd.Flags().String("workspace", "", "") // register persistent flag manually
cmd.Flags().Bool("force", false, "") // register persistent flag manually
require.NoError(cmd.ParseFlags(tc.args))
fileHandler := file.NewHandler(afero.NewMemMapFs())
if tc.writeStateFile {
require.NoError(
state.New().
SetInfrastructure(state.Infrastructure{ClusterEndpoint: "192.0.2.42"}).
WriteToFile(fileHandler, constants.StateFilename),
)
}
r := &recoverCmd{log: logger.NewTest(t)}
flags, err := r.parseRecoverFlags(cmd, fileHandler)
if tc.wantErr {
assert.Error(err)
return
}
assert.NoError(err)
assert.Equal(tc.wantFlags, flags)
})
}
}
func TestDoRecovery(t *testing.T) {
testCases := map[string]struct {
recoveryServer *stubRecoveryServer

View File

@ -45,11 +45,6 @@ func runStatus(cmd *cobra.Command, _ []string) error {
}
defer log.Sync()
flags, err := parseStatusFlags(cmd)
if err != nil {
return fmt.Errorf("parsing flags: %w", err)
}
fileHandler := file.NewHandler(afero.NewOsFs())
helmClient, err := helm.NewReleaseVersionClient(constants.AdminConfFilename, log)
@ -61,55 +56,61 @@ func runStatus(cmd *cobra.Command, _ []string) error {
}
fetcher := attestationconfigapi.NewFetcher()
conf, err := config.New(fileHandler, constants.ConfigFilename, fetcher, flags.force)
var configValidationErr *config.ValidationError
if errors.As(err, &configValidationErr) {
cmd.PrintErrln(configValidationErr.LongMessage())
}
variant := conf.GetAttestationConfig().GetVariant()
kubeClient, err := kubecmd.New(cmd.OutOrStdout(), constants.AdminConfFilename, fileHandler, log)
if err != nil {
return fmt.Errorf("setting up kubernetes client: %w", err)
}
output, err := status(cmd.Context(), helmVersionGetter, kubeClient, variant)
if err != nil {
return fmt.Errorf("getting status: %w", err)
s := statusCmd{log: log, fileHandler: fileHandler}
if err := s.flags.parse(cmd.Flags()); err != nil {
return err
}
return s.status(cmd, helmVersionGetter, kubeClient, fetcher)
}
cmd.Print(output)
return nil
type statusCmd struct {
log debugLog
fileHandler file.Handler
flags rootFlags
}
// status queries the cluster for the relevant status information and returns the output string.
func status(ctx context.Context, getHelmVersions func() (fmt.Stringer, error), kubeClient kubeCmd, attestVariant variant.Variant,
) (string, error) {
nodeVersion, err := kubeClient.GetConstellationVersion(ctx)
if err != nil {
return "", fmt.Errorf("getting constellation version: %w", err)
func (s *statusCmd) status(
cmd *cobra.Command, getHelmVersions func() (fmt.Stringer, error),
kubeClient kubeCmd, fetcher attestationconfigapi.Fetcher,
) error {
conf, err := config.New(s.fileHandler, constants.ConfigFilename, fetcher, s.flags.force)
var configValidationErr *config.ValidationError
if errors.As(err, &configValidationErr) {
cmd.PrintErrln(configValidationErr.LongMessage())
}
attestationConfig, err := kubeClient.GetClusterAttestationConfig(ctx, attestVariant)
nodeVersion, err := kubeClient.GetConstellationVersion(cmd.Context())
if err != nil {
return "", fmt.Errorf("getting attestation config: %w", err)
return fmt.Errorf("getting constellation version: %w", err)
}
attestationConfig, err := kubeClient.GetClusterAttestationConfig(cmd.Context(), conf.GetAttestationConfig().GetVariant())
if err != nil {
return fmt.Errorf("getting attestation config: %w", err)
}
prettyYAML, err := yaml.Marshal(attestationConfig)
if err != nil {
return "", fmt.Errorf("marshalling attestation config: %w", err)
return fmt.Errorf("marshalling attestation config: %w", err)
}
serviceVersions, err := getHelmVersions()
if err != nil {
return "", fmt.Errorf("getting service versions: %w", err)
return fmt.Errorf("getting service versions: %w", err)
}
status, err := kubeClient.ClusterStatus(ctx)
status, err := kubeClient.ClusterStatus(cmd.Context())
if err != nil {
return "", fmt.Errorf("getting cluster status: %w", err)
return fmt.Errorf("getting cluster status: %w", err)
}
return statusOutput(nodeVersion, serviceVersions, status, string(prettyYAML)), nil
cmd.Print(statusOutput(nodeVersion, serviceVersions, status, string(prettyYAML)))
return nil
}
// statusOutput creates the status cmd output string by formatting the received information.
@ -167,26 +168,6 @@ func targetVersionsString(target kubecmd.NodeVersion) string {
return builder.String()
}
type statusFlags struct {
workspace string
force bool
}
func parseStatusFlags(cmd *cobra.Command) (statusFlags, error) {
workspace, err := cmd.Flags().GetString("workspace")
if err != nil {
return statusFlags{}, fmt.Errorf("getting config flag: %w", err)
}
force, err := cmd.Flags().GetBool("force")
if err != nil {
return statusFlags{}, fmt.Errorf("getting config flag: %w", err)
}
return statusFlags{
workspace: workspace,
force: force,
}, nil
}
type kubeCmd interface {
ClusterStatus(ctx context.Context) (map[string]kubecmd.NodeStatus, error)
GetConstellationVersion(ctx context.Context) (kubecmd.NodeVersion, error)

View File

@ -7,6 +7,7 @@ SPDX-License-Identifier: AGPL-3.0-only
package cmd
import (
"bytes"
"context"
"fmt"
"testing"
@ -14,8 +15,12 @@ import (
"github.com/edgelesssys/constellation/v2/cli/internal/kubecmd"
"github.com/edgelesssys/constellation/v2/internal/attestation/measurements"
"github.com/edgelesssys/constellation/v2/internal/attestation/variant"
"github.com/edgelesssys/constellation/v2/internal/cloud/cloudprovider"
"github.com/edgelesssys/constellation/v2/internal/config"
"github.com/edgelesssys/constellation/v2/internal/constants"
"github.com/edgelesssys/constellation/v2/internal/file"
updatev1alpha1 "github.com/edgelesssys/constellation/v2/operators/constellation-node-operator/v2/api/v1alpha1"
"github.com/spf13/afero"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
corev1 "k8s.io/api/core/v1"
@ -63,7 +68,6 @@ func TestStatus(t *testing.T) {
testCases := map[string]struct {
kubeClient stubKubeClient
attestVariant variant.Variant
expectedOutput string
wantErr bool
}{
@ -104,7 +108,6 @@ func TestStatus(t *testing.T) {
},
},
},
attestVariant: variant.QEMUVTPM{},
expectedOutput: successOutput,
},
"one of two nodes not upgraded": {
@ -157,7 +160,6 @@ func TestStatus(t *testing.T) {
},
},
},
attestVariant: variant.QEMUVTPM{},
expectedOutput: inProgressOutput,
},
"error getting node status": {
@ -183,7 +185,6 @@ func TestStatus(t *testing.T) {
},
},
},
attestVariant: variant.QEMUVTPM{},
expectedOutput: successOutput,
wantErr: true,
},
@ -211,7 +212,6 @@ func TestStatus(t *testing.T) {
},
},
},
attestVariant: variant.QEMUVTPM{},
expectedOutput: successOutput,
wantErr: true,
},
@ -248,7 +248,6 @@ func TestStatus(t *testing.T) {
}),
attestationErr: assert.AnError,
},
attestVariant: variant.QEMUVTPM{},
expectedOutput: successOutput,
wantErr: true,
},
@ -259,19 +258,31 @@ func TestStatus(t *testing.T) {
require := require.New(t)
assert := assert.New(t)
variant := variant.AWSNitroTPM{}
output, err := status(
context.Background(),
cmd := NewStatusCmd()
var out bytes.Buffer
cmd.SetOut(&out)
var errOut bytes.Buffer
cmd.SetErr(&errOut)
fileHandler := file.NewHandler(afero.NewMemMapFs())
cfg, err := createConfigWithAttestationVariant(cloudprovider.QEMU, "", variant.QEMUVTPM{})
require.NoError(err)
require.NoError(fileHandler.WriteYAML(constants.ConfigFilename, cfg))
s := statusCmd{fileHandler: fileHandler}
err = s.status(
cmd,
stubGetVersions(versionsOutput),
tc.kubeClient,
variant,
stubAttestationFetcher{},
)
if tc.wantErr {
assert.Error(err)
return
}
require.NoError(err)
assert.Equal(tc.expectedOutput, output)
assert.Equal(tc.expectedOutput, out.String())
})
}
}

View File

@ -13,10 +13,9 @@ import (
"github.com/spf13/afero"
"github.com/spf13/cobra"
"github.com/spf13/pflag"
"github.com/edgelesssys/constellation/v2/cli/internal/cloudcmd"
"github.com/edgelesssys/constellation/v2/cli/internal/cmd/pathprefix"
"github.com/edgelesssys/constellation/v2/cli/internal/terraform"
"github.com/edgelesssys/constellation/v2/internal/constants"
"github.com/edgelesssys/constellation/v2/internal/file"
)
@ -35,9 +34,26 @@ func NewTerminateCmd() *cobra.Command {
return cmd
}
type terminateFlags struct {
rootFlags
yes bool
}
func (f *terminateFlags) parse(flags *pflag.FlagSet) error {
if err := f.rootFlags.parse(flags); err != nil {
return err
}
yes, err := flags.GetBool("yes")
if err != nil {
return fmt.Errorf("getting 'yes' flag: %w", err)
}
f.yes = yes
return nil
}
// runTerminate runs the terminate command.
func runTerminate(cmd *cobra.Command, _ []string) error {
fileHandler := file.NewHandler(afero.NewOsFs())
spinner, err := newSpinnerOrStderr(cmd)
if err != nil {
return fmt.Errorf("creating spinner: %w", err)
@ -45,18 +61,27 @@ func runTerminate(cmd *cobra.Command, _ []string) error {
defer spinner.Stop()
terminator := cloudcmd.NewTerminator()
return terminate(cmd, terminator, fileHandler, spinner)
logger, err := newCLILogger(cmd)
if err != nil {
return fmt.Errorf("creating logger: %w", err)
}
t := &terminateCmd{log: logger, fileHandler: file.NewHandler(afero.NewOsFs())}
if err := t.flags.parse(cmd.Flags()); err != nil {
return err
}
return t.terminate(cmd, terminator, spinner)
}
func terminate(cmd *cobra.Command, terminator cloudTerminator, fileHandler file.Handler, spinner spinnerInterf,
) error {
flags, err := parseTerminateFlags(cmd)
if err != nil {
return fmt.Errorf("parsing flags: %w", err)
}
pf := pathprefix.New(flags.workspace)
type terminateCmd struct {
log debugLog
fileHandler file.Handler
flags terminateFlags
}
if !flags.yes {
func (t *terminateCmd) terminate(cmd *cobra.Command, terminator cloudTerminator, spinner spinnerInterf) error {
if !t.flags.yes {
cmd.Println("You are about to terminate a Constellation cluster.")
cmd.Println("All of its associated resources will be DESTROYED.")
cmd.Println("This action is irreversible and ALL DATA WILL BE LOST.")
@ -71,7 +96,7 @@ func terminate(cmd *cobra.Command, terminator cloudTerminator, fileHandler file.
}
spinner.Start("Terminating", false)
err = terminator.Terminate(cmd.Context(), constants.TerraformWorkingDir, flags.logLevel)
err := terminator.Terminate(cmd.Context(), constants.TerraformWorkingDir, t.flags.tfLogLevel)
spinner.Stop()
if err != nil {
return fmt.Errorf("terminating Constellation cluster: %w", err)
@ -80,44 +105,13 @@ func terminate(cmd *cobra.Command, terminator cloudTerminator, fileHandler file.
cmd.Println("Your Constellation cluster was terminated successfully.")
var removeErr error
if err := fileHandler.Remove(constants.AdminConfFilename); err != nil && !errors.Is(err, fs.ErrNotExist) {
removeErr = errors.Join(err, fmt.Errorf("failed to remove file: '%s', please remove it manually", pf.PrefixPrintablePath(constants.AdminConfFilename)))
if err := t.fileHandler.Remove(constants.AdminConfFilename); err != nil && !errors.Is(err, fs.ErrNotExist) {
removeErr = errors.Join(err, fmt.Errorf("failed to remove file: '%s', please remove it manually", t.flags.pathPrefixer.PrefixPrintablePath(constants.AdminConfFilename)))
}
if err := fileHandler.Remove(constants.StateFilename); err != nil && !errors.Is(err, fs.ErrNotExist) {
removeErr = errors.Join(err, fmt.Errorf("failed to remove file: '%s', please remove it manually", pf.PrefixPrintablePath(constants.StateFilename)))
if err := t.fileHandler.Remove(constants.StateFilename); err != nil && !errors.Is(err, fs.ErrNotExist) {
removeErr = errors.Join(err, fmt.Errorf("failed to remove file: '%s', please remove it manually", t.flags.pathPrefixer.PrefixPrintablePath(constants.StateFilename)))
}
return removeErr
}
type terminateFlags struct {
yes bool
workspace string
logLevel terraform.LogLevel
}
func parseTerminateFlags(cmd *cobra.Command) (terminateFlags, error) {
yes, err := cmd.Flags().GetBool("yes")
if err != nil {
return terminateFlags{}, fmt.Errorf("parsing yes bool: %w", err)
}
logLevelString, err := cmd.Flags().GetString("tf-log")
if err != nil {
return terminateFlags{}, fmt.Errorf("parsing tf-log string: %w", err)
}
logLevel, err := terraform.ParseLogLevel(logLevelString)
if err != nil {
return terminateFlags{}, fmt.Errorf("parsing Terraform log level %s: %w", logLevelString, err)
}
workspace, err := cmd.Flags().GetString("workspace")
if err != nil {
return terminateFlags{}, fmt.Errorf("parsing workspace string: %w", err)
}
return terminateFlags{
yes: yes,
workspace: workspace,
logLevel: logLevel,
}, nil
}

View File

@ -14,6 +14,7 @@ import (
"github.com/edgelesssys/constellation/v2/cli/internal/state"
"github.com/edgelesssys/constellation/v2/internal/constants"
"github.com/edgelesssys/constellation/v2/internal/file"
"github.com/edgelesssys/constellation/v2/internal/logger"
"github.com/spf13/afero"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@ -134,18 +135,17 @@ func TestTerminate(t *testing.T) {
cmd.SetErr(&bytes.Buffer{})
cmd.SetIn(bytes.NewBufferString(tc.stdin))
// register persistent flags manually
cmd.Flags().String("tf-log", "NONE", "")
cmd.Flags().String("workspace", "", "")
require.NotNil(tc.setupFs)
fileHandler := file.NewHandler(tc.setupFs(require, tc.stateFile))
if tc.yesFlag {
require.NoError(cmd.Flags().Set("yes", "true"))
tCmd := &terminateCmd{
log: logger.NewTest(t),
fileHandler: fileHandler,
flags: terminateFlags{
yes: tc.yesFlag,
},
}
err := terminate(cmd, tc.terminator, fileHandler, &nopSpinner{})
err := tCmd.terminate(cmd, tc.terminator, &nopSpinner{})
if tc.wantErr {
assert.Error(err)

View File

@ -16,7 +16,6 @@ import (
"time"
"github.com/edgelesssys/constellation/v2/cli/internal/cloudcmd"
"github.com/edgelesssys/constellation/v2/cli/internal/cmd/pathprefix"
"github.com/edgelesssys/constellation/v2/cli/internal/helm"
"github.com/edgelesssys/constellation/v2/cli/internal/kubecmd"
"github.com/edgelesssys/constellation/v2/cli/internal/state"
@ -33,6 +32,7 @@ import (
"github.com/rogpeppe/go-internal/diff"
"github.com/spf13/afero"
"github.com/spf13/cobra"
"github.com/spf13/pflag"
"gopkg.in/yaml.v3"
apiextensionsv1 "k8s.io/apiextensions-apiserver/pkg/apis/apiextensions/v1"
)
@ -76,12 +76,62 @@ func newUpgradeApplyCmd() *cobra.Command {
return cmd
}
func runUpgradeApply(cmd *cobra.Command, _ []string) error {
flags, err := parseUpgradeApplyFlags(cmd)
if err != nil {
return fmt.Errorf("parsing flags: %w", err)
type upgradeApplyFlags struct {
rootFlags
yes bool
upgradeTimeout time.Duration
conformance bool
helmWaitMode helm.WaitMode
skipPhases skipPhases
}
func (f *upgradeApplyFlags) parse(flags *pflag.FlagSet) error {
if err := f.rootFlags.parse(flags); err != nil {
return err
}
rawSkipPhases, err := flags.GetStringSlice("skip-phases")
if err != nil {
return fmt.Errorf("parsing skip-phases flag: %w", err)
}
var skipPhases []skipPhase
for _, phase := range rawSkipPhases {
switch skipPhase(phase) {
case skipInfrastructurePhase, skipHelmPhase, skipImagePhase, skipK8sPhase:
skipPhases = append(skipPhases, skipPhase(phase))
default:
return fmt.Errorf("invalid phase %s", phase)
}
}
f.skipPhases = skipPhases
f.yes, err = flags.GetBool("yes")
if err != nil {
return fmt.Errorf("getting 'yes' flag: %w", err)
}
f.upgradeTimeout, err = flags.GetDuration("timeout")
if err != nil {
return fmt.Errorf("getting 'timeout' flag: %w", err)
}
f.conformance, err = flags.GetBool("conformance")
if err != nil {
return fmt.Errorf("getting 'conformance' flag: %w", err)
}
skipHelmWait, err := flags.GetBool("skip-helm-wait")
if err != nil {
return fmt.Errorf("getting 'skip-helm-wait' flag: %w", err)
}
f.helmWaitMode = helm.WaitModeAtomic
if skipHelmWait {
f.helmWaitMode = helm.WaitModeNone
}
return nil
}
func runUpgradeApply(cmd *cobra.Command, _ []string) error {
log, err := newCLILogger(cmd)
if err != nil {
return fmt.Errorf("creating logger: %w", err)
@ -98,13 +148,18 @@ func runUpgradeApply(cmd *cobra.Command, _ []string) error {
configFetcher := attestationconfigapi.NewFetcher()
var flags upgradeApplyFlags
if err := flags.parse(cmd.Flags()); err != nil {
return err
}
// Set up terraform upgrader
upgradeDir := filepath.Join(constants.UpgradeDir, upgradeID)
clusterUpgrader, err := cloudcmd.NewClusterUpgrader(
cmd.Context(),
constants.TerraformWorkingDir,
upgradeDir,
flags.terraformLogLevel,
flags.tfLogLevel,
fileHandler,
)
if err != nil {
@ -122,9 +177,10 @@ func runUpgradeApply(cmd *cobra.Command, _ []string) error {
clusterUpgrader: clusterUpgrader,
configFetcher: configFetcher,
fileHandler: fileHandler,
flags: flags,
log: log,
}
return applyCmd.upgradeApply(cmd, upgradeDir, flags)
return applyCmd.upgradeApply(cmd, upgradeDir)
}
type upgradeApplyCmd struct {
@ -133,11 +189,12 @@ type upgradeApplyCmd struct {
clusterUpgrader clusterUpgrader
configFetcher attestationconfigapi.Fetcher
fileHandler file.Handler
flags upgradeApplyFlags
log debugLog
}
func (u *upgradeApplyCmd) upgradeApply(cmd *cobra.Command, upgradeDir string, flags upgradeApplyFlags) error {
conf, err := config.New(u.fileHandler, constants.ConfigFilename, u.configFetcher, flags.force)
func (u *upgradeApplyCmd) upgradeApply(cmd *cobra.Command, upgradeDir string) error {
conf, err := config.New(u.fileHandler, constants.ConfigFilename, u.configFetcher, u.flags.force)
var configValidationErr *config.ValidationError
if errors.As(err, &configValidationErr) {
cmd.PrintErrln(configValidationErr.LongMessage())
@ -147,7 +204,7 @@ func (u *upgradeApplyCmd) upgradeApply(cmd *cobra.Command, upgradeDir string, fl
}
if cloudcmd.UpgradeRequiresIAMMigration(conf.GetProvider()) {
cmd.Println("WARNING: This upgrade requires an IAM migration. Please make sure you have applied the IAM migration using `iam upgrade apply` before continuing.")
if !flags.yes {
if !u.flags.yes {
yes, err := askToConfirm(cmd, "Did you upgrade the IAM resources?")
if err != nil {
return fmt.Errorf("asking for confirmation: %w", err)
@ -158,7 +215,7 @@ func (u *upgradeApplyCmd) upgradeApply(cmd *cobra.Command, upgradeDir string, fl
}
}
}
conf.KubernetesVersion, err = validK8sVersion(cmd, string(conf.KubernetesVersion), flags.yes)
conf.KubernetesVersion, err = validK8sVersion(cmd, string(conf.KubernetesVersion), u.flags.yes)
if err != nil {
return err
}
@ -168,21 +225,21 @@ func (u *upgradeApplyCmd) upgradeApply(cmd *cobra.Command, upgradeDir string, fl
return fmt.Errorf("reading state file: %w", err)
}
if err := u.confirmAndUpgradeAttestationConfig(cmd, conf.GetAttestationConfig(), stateFile.ClusterValues.MeasurementSalt, flags); err != nil {
if err := u.confirmAndUpgradeAttestationConfig(cmd, conf.GetAttestationConfig(), stateFile.ClusterValues.MeasurementSalt); err != nil {
return fmt.Errorf("upgrading measurements: %w", err)
}
// If infrastructure phase is skipped, we expect the new infrastructure
// to be in the Terraform configuration already. Otherwise, perform
// the Terraform migrations.
if !flags.skipPhases.contains(skipInfrastructurePhase) {
if !u.flags.skipPhases.contains(skipInfrastructurePhase) {
migrationRequired, err := u.planTerraformMigration(cmd, conf)
if err != nil {
return fmt.Errorf("planning Terraform migrations: %w", err)
}
if migrationRequired {
postMigrationInfraState, err := u.migrateTerraform(cmd, conf, upgradeDir, flags)
postMigrationInfraState, err := u.migrateTerraform(cmd, conf, upgradeDir)
if err != nil {
return fmt.Errorf("performing Terraform migrations: %w", err)
}
@ -217,8 +274,8 @@ func (u *upgradeApplyCmd) upgradeApply(cmd *cobra.Command, upgradeDir string, fl
}
var upgradeErr *compatibility.InvalidUpgradeError
if !flags.skipPhases.contains(skipHelmPhase) {
err = u.handleServiceUpgrade(cmd, conf, stateFile, upgradeDir, flags)
if !u.flags.skipPhases.contains(skipHelmPhase) {
err = u.handleServiceUpgrade(cmd, conf, stateFile, upgradeDir)
switch {
case errors.As(err, &upgradeErr):
cmd.PrintErrln(err)
@ -228,10 +285,10 @@ func (u *upgradeApplyCmd) upgradeApply(cmd *cobra.Command, upgradeDir string, fl
return fmt.Errorf("upgrading services: %w", err)
}
}
skipImageUpgrade := flags.skipPhases.contains(skipImagePhase)
skipK8sUpgrade := flags.skipPhases.contains(skipK8sPhase)
skipImageUpgrade := u.flags.skipPhases.contains(skipImagePhase)
skipK8sUpgrade := u.flags.skipPhases.contains(skipK8sPhase)
if !(skipImageUpgrade && skipK8sUpgrade) {
err = u.kubeUpgrader.UpgradeNodeVersion(cmd.Context(), conf, flags.force, skipImageUpgrade, skipK8sUpgrade)
err = u.kubeUpgrader.UpgradeNodeVersion(cmd.Context(), conf, u.flags.force, skipImageUpgrade, skipK8sUpgrade)
switch {
case errors.Is(err, kubecmd.ErrInProgress):
cmd.PrintErrln("Skipping image and Kubernetes upgrades. Another upgrade is in progress.")
@ -284,12 +341,11 @@ func (u *upgradeApplyCmd) planTerraformMigration(cmd *cobra.Command, conf *confi
// migrateTerraform checks if the Constellation version the cluster is being upgraded to requires a migration
// of cloud resources with Terraform. If so, the migration is performed and the post-migration infrastructure state is returned.
// If no migration is required, the current (pre-upgrade) infrastructure state is returned.
func (u *upgradeApplyCmd) migrateTerraform(
cmd *cobra.Command, conf *config.Config, upgradeDir string, flags upgradeApplyFlags,
func (u *upgradeApplyCmd) migrateTerraform(cmd *cobra.Command, conf *config.Config, upgradeDir string,
) (state.Infrastructure, error) {
// If there are any Terraform migrations to apply, ask for confirmation
fmt.Fprintln(cmd.OutOrStdout(), "The upgrade requires a migration of Constellation cloud resources by applying an updated Terraform template. Please manually review the suggested changes below.")
if !flags.yes {
if !u.flags.yes {
ok, err := askToConfirm(cmd, "Do you want to apply the Terraform migrations?")
if err != nil {
return state.Infrastructure{}, fmt.Errorf("asking for confirmation: %w", err)
@ -317,8 +373,8 @@ func (u *upgradeApplyCmd) migrateTerraform(
cmd.Printf("Infrastructure migrations applied successfully and output written to: %s\n"+
"A backup of the pre-upgrade state has been written to: %s\n",
flags.pf.PrefixPrintablePath(constants.StateFilename),
flags.pf.PrefixPrintablePath(filepath.Join(upgradeDir, constants.TerraformUpgradeBackupDir)),
u.flags.pathPrefixer.PrefixPrintablePath(constants.StateFilename),
u.flags.pathPrefixer.PrefixPrintablePath(filepath.Join(upgradeDir, constants.TerraformUpgradeBackupDir)),
)
return infraState, nil
}
@ -347,7 +403,7 @@ func validK8sVersion(cmd *cobra.Command, version string, yes bool) (validVersion
// confirmAndUpgradeAttestationConfig checks if the locally configured measurements are different from the cluster's measurements.
// If so the function will ask the user to confirm (if --yes is not set) and upgrade the cluster's config.
func (u *upgradeApplyCmd) confirmAndUpgradeAttestationConfig(
cmd *cobra.Command, newConfig config.AttestationCfg, measurementSalt []byte, flags upgradeApplyFlags,
cmd *cobra.Command, newConfig config.AttestationCfg, measurementSalt []byte,
) error {
clusterAttestationConfig, err := u.kubeUpgrader.GetClusterAttestationConfig(cmd.Context(), newConfig.GetVariant())
if err != nil {
@ -369,7 +425,7 @@ func (u *upgradeApplyCmd) confirmAndUpgradeAttestationConfig(
}
cmd.Println("The following changes will be applied to the attestation config:")
cmd.Println(diffStr)
if !flags.yes {
if !u.flags.yes {
ok, err := askToConfirm(cmd, "Are you sure you want to change your cluster's attestation config?")
if err != nil {
return fmt.Errorf("asking for confirmation: %w", err)
@ -387,21 +443,20 @@ func (u *upgradeApplyCmd) confirmAndUpgradeAttestationConfig(
}
func (u *upgradeApplyCmd) handleServiceUpgrade(
cmd *cobra.Command, conf *config.Config, stateFile *state.State,
upgradeDir string, flags upgradeApplyFlags,
cmd *cobra.Command, conf *config.Config, stateFile *state.State, upgradeDir string,
) error {
var secret uri.MasterSecret
if err := u.fileHandler.ReadJSON(constants.MasterSecretFilename, &secret); err != nil {
return fmt.Errorf("reading master secret: %w", err)
}
serviceAccURI, err := cloudcmd.GetMarshaledServiceAccountURI(conf.GetProvider(), conf, flags.pf, u.log, u.fileHandler)
serviceAccURI, err := cloudcmd.GetMarshaledServiceAccountURI(conf.GetProvider(), conf, u.flags.pathPrefixer, u.log, u.fileHandler)
if err != nil {
return fmt.Errorf("getting service account URI: %w", err)
}
options := helm.Options{
Force: flags.force,
Conformance: flags.conformance,
HelmWaitMode: flags.helmWaitMode,
Force: u.flags.force,
Conformance: u.flags.conformance,
HelmWaitMode: u.flags.helmWaitMode,
}
prepareApply := func(allowDestructive bool) (helm.Applier, bool, error) {
@ -422,7 +477,7 @@ func (u *upgradeApplyCmd) handleServiceUpgrade(
if !errors.Is(err, helm.ErrConfirmationMissing) {
return fmt.Errorf("upgrading charts with deny destructive mode: %w", err)
}
if !flags.yes {
if !u.flags.yes {
cmd.PrintErrln("WARNING: Upgrading cert-manager will destroy all custom resources you have manually created that are based on the current version of cert-manager.")
ok, askErr := askToConfirm(cmd, "Do you want to upgrade cert-manager anyway?")
if askErr != nil {
@ -463,86 +518,6 @@ func (u *upgradeApplyCmd) handleServiceUpgrade(
return nil
}
func parseUpgradeApplyFlags(cmd *cobra.Command) (upgradeApplyFlags, error) {
workDir, err := cmd.Flags().GetString("workspace")
if err != nil {
return upgradeApplyFlags{}, err
}
yes, err := cmd.Flags().GetBool("yes")
if err != nil {
return upgradeApplyFlags{}, err
}
timeout, err := cmd.Flags().GetDuration("timeout")
if err != nil {
return upgradeApplyFlags{}, err
}
force, err := cmd.Flags().GetBool("force")
if err != nil {
return upgradeApplyFlags{}, fmt.Errorf("parsing force argument: %w", err)
}
logLevelString, err := cmd.Flags().GetString("tf-log")
if err != nil {
return upgradeApplyFlags{}, fmt.Errorf("parsing tf-log string: %w", err)
}
logLevel, err := terraform.ParseLogLevel(logLevelString)
if err != nil {
return upgradeApplyFlags{}, fmt.Errorf("parsing Terraform log level %s: %w", logLevelString, err)
}
conformance, err := cmd.Flags().GetBool("conformance")
if err != nil {
return upgradeApplyFlags{}, fmt.Errorf("parsing conformance flag: %w", err)
}
skipHelmWait, err := cmd.Flags().GetBool("skip-helm-wait")
if err != nil {
return upgradeApplyFlags{}, fmt.Errorf("parsing skip-helm-wait flag: %w", err)
}
helmWaitMode := helm.WaitModeAtomic
if skipHelmWait {
helmWaitMode = helm.WaitModeNone
}
rawSkipPhases, err := cmd.Flags().GetStringSlice("skip-phases")
if err != nil {
return upgradeApplyFlags{}, fmt.Errorf("parsing skip-phases flag: %w", err)
}
var skipPhases []skipPhase
for _, phase := range rawSkipPhases {
switch skipPhase(phase) {
case skipInfrastructurePhase, skipHelmPhase, skipImagePhase, skipK8sPhase:
skipPhases = append(skipPhases, skipPhase(phase))
default:
return upgradeApplyFlags{}, fmt.Errorf("invalid phase %s", phase)
}
}
return upgradeApplyFlags{
pf: pathprefix.New(workDir),
yes: yes,
upgradeTimeout: timeout,
force: force,
terraformLogLevel: logLevel,
conformance: conformance,
helmWaitMode: helmWaitMode,
skipPhases: skipPhases,
}, nil
}
type upgradeApplyFlags struct {
pf pathprefix.PathPrefixer
yes bool
upgradeTimeout time.Duration
force bool
terraformLogLevel terraform.LogLevel
conformance bool
helmWaitMode helm.WaitMode
skipPhases skipPhases
}
// skipPhases is a list of phases that can be skipped during the upgrade process.
type skipPhases []skipPhase

View File

@ -227,10 +227,11 @@ func TestUpgradeApply(t *testing.T) {
clusterUpgrader: tc.terraformUpgrader,
log: logger.NewTest(t),
configFetcher: stubAttestationFetcher{},
flags: tc.flags,
fileHandler: fh,
}
err := upgrader.upgradeApply(cmd, "test", tc.flags)
err := upgrader.upgradeApply(cmd, "test")
if tc.wantErr {
assert.Error(err)
return
@ -247,16 +248,20 @@ func TestUpgradeApply(t *testing.T) {
}
func TestUpgradeApplyFlagsForSkipPhases(t *testing.T) {
require := require.New(t)
cmd := newUpgradeApplyCmd()
cmd.Flags().String("workspace", "", "") // register persistent flag manually
cmd.Flags().Bool("force", true, "") // register persistent flag manually
cmd.Flags().String("tf-log", "NONE", "") // register persistent flag manually
require.NoError(t, cmd.Flags().Set("skip-phases", "infrastructure,helm,k8s,image"))
result, err := parseUpgradeApplyFlags(cmd)
if err != nil {
t.Fatalf("Error while parsing flags: %v", err)
}
assert.ElementsMatch(t, []skipPhase{skipInfrastructurePhase, skipHelmPhase, skipK8sPhase, skipImagePhase}, result.skipPhases)
// register persistent flags manually
cmd.Flags().String("workspace", "", "")
cmd.Flags().Bool("force", true, "")
cmd.Flags().String("tf-log", "NONE", "")
cmd.Flags().Bool("debug", false, "")
require.NoError(cmd.Flags().Set("skip-phases", "infrastructure,helm,k8s,image"))
var flags upgradeApplyFlags
err := flags.parse(cmd.Flags())
require.NoError(err)
assert.ElementsMatch(t, []skipPhase{skipInfrastructurePhase, skipHelmPhase, skipK8sPhase, skipImagePhase}, flags.skipPhases)
}
type stubKubernetesUpgrader struct {

View File

@ -38,6 +38,7 @@ import (
"github.com/siderolabs/talos/pkg/machinery/config/encoder"
"github.com/spf13/afero"
"github.com/spf13/cobra"
"github.com/spf13/pflag"
"golang.org/x/mod/semver"
)
@ -57,6 +58,36 @@ func newUpgradeCheckCmd() *cobra.Command {
return cmd
}
type upgradeCheckFlags struct {
rootFlags
updateConfig bool
ref string
stream string
}
func (f *upgradeCheckFlags) parse(flags *pflag.FlagSet) error {
if err := f.rootFlags.parse(flags); err != nil {
return err
}
updateConfig, err := flags.GetBool("update-config")
if err != nil {
return fmt.Errorf("getting 'update-config' flag: %w", err)
}
f.updateConfig = updateConfig
f.ref, err = flags.GetString("ref")
if err != nil {
return fmt.Errorf("getting 'ref' flag: %w", err)
}
f.stream, err = flags.GetString("stream")
if err != nil {
return fmt.Errorf("getting 'stream' flag: %w", err)
}
return nil
}
func runUpgradeCheck(cmd *cobra.Command, _ []string) error {
log, err := newCLILogger(cmd)
if err != nil {
@ -64,8 +95,8 @@ func runUpgradeCheck(cmd *cobra.Command, _ []string) error {
}
defer log.Sync()
flags, err := parseUpgradeCheckFlags(cmd)
if err != nil {
var flags upgradeCheckFlags
if err := flags.parse(cmd.Flags()); err != nil {
return err
}
@ -77,7 +108,7 @@ func runUpgradeCheck(cmd *cobra.Command, _ []string) error {
cmd.Context(),
constants.TerraformWorkingDir,
upgradeDir,
flags.terraformLogLevel,
flags.tfLogLevel,
fileHandler,
)
if err != nil {
@ -111,46 +142,11 @@ func runUpgradeCheck(cmd *cobra.Command, _ []string) error {
upgradeDir: upgradeDir,
terraformChecker: tfClient,
fileHandler: fileHandler,
flags: flags,
log: log,
}
return up.upgradeCheck(cmd, attestationconfigapi.NewFetcher(), flags)
}
func parseUpgradeCheckFlags(cmd *cobra.Command) (upgradeCheckFlags, error) {
force, err := cmd.Flags().GetBool("force")
if err != nil {
return upgradeCheckFlags{}, fmt.Errorf("parsing force bool: %w", err)
}
updateConfig, err := cmd.Flags().GetBool("update-config")
if err != nil {
return upgradeCheckFlags{}, fmt.Errorf("parsing update-config bool: %w", err)
}
ref, err := cmd.Flags().GetString("ref")
if err != nil {
return upgradeCheckFlags{}, fmt.Errorf("parsing ref string: %w", err)
}
stream, err := cmd.Flags().GetString("stream")
if err != nil {
return upgradeCheckFlags{}, fmt.Errorf("parsing stream string: %w", err)
}
logLevelString, err := cmd.Flags().GetString("tf-log")
if err != nil {
return upgradeCheckFlags{}, fmt.Errorf("parsing tf-log string: %w", err)
}
logLevel, err := terraform.ParseLogLevel(logLevelString)
if err != nil {
return upgradeCheckFlags{}, fmt.Errorf("parsing Terraform log level %s: %w", logLevelString, err)
}
return upgradeCheckFlags{
force: force,
updateConfig: updateConfig,
ref: ref,
stream: stream,
terraformLogLevel: logLevel,
}, nil
return up.upgradeCheck(cmd, attestationconfigapi.NewFetcher())
}
type upgradeCheckCmd struct {
@ -159,12 +155,13 @@ type upgradeCheckCmd struct {
collect collector
terraformChecker terraformChecker
fileHandler file.Handler
flags upgradeCheckFlags
log debugLog
}
// upgradePlan plans an upgrade of a Constellation cluster.
func (u *upgradeCheckCmd) upgradeCheck(cmd *cobra.Command, fetcher attestationconfigapi.Fetcher, flags upgradeCheckFlags) error {
conf, err := config.New(u.fileHandler, constants.ConfigFilename, fetcher, flags.force)
func (u *upgradeCheckCmd) upgradeCheck(cmd *cobra.Command, fetcher attestationconfigapi.Fetcher) error {
conf, err := config.New(u.fileHandler, constants.ConfigFilename, fetcher, u.flags.force)
var configValidationErr *config.ValidationError
if errors.As(err, &configValidationErr) {
cmd.PrintErrln(configValidationErr.LongMessage())
@ -271,7 +268,7 @@ func (u *upgradeCheckCmd) upgradeCheck(cmd *cobra.Command, fetcher attestationco
// Using Print over Println as buildString already includes a trailing newline where necessary.
cmd.Print(updateMsg)
if flags.updateConfig {
if u.flags.updateConfig {
if err := upgrade.writeConfig(conf, u.fileHandler, constants.ConfigFilename); err != nil {
return fmt.Errorf("writing config: %w", err)
}
@ -725,14 +722,6 @@ func (v *versionCollector) filterCompatibleCLIVersions(ctx context.Context, cliP
return compatibleVersions, nil
}
type upgradeCheckFlags struct {
force bool
updateConfig bool
ref string
stream string
terraformLogLevel terraform.LogLevel
}
type kubernetesChecker interface {
GetConstellationVersion(ctx context.Context) (kubecmd.NodeVersion, error)
}

View File

@ -221,7 +221,7 @@ func TestUpgradeCheck(t *testing.T) {
cmd := newUpgradeCheckCmd()
err := checkCmd.upgradeCheck(cmd, stubAttestationFetcher{}, upgradeCheckFlags{})
err := checkCmd.upgradeCheck(cmd, stubAttestationFetcher{})
if tc.wantError {
assert.Error(err)
return

View File

@ -26,7 +26,6 @@ import (
tpmProto "github.com/google/go-tpm-tools/proto/tpm"
"github.com/edgelesssys/constellation/v2/cli/internal/cloudcmd"
"github.com/edgelesssys/constellation/v2/cli/internal/cmd/pathprefix"
"github.com/edgelesssys/constellation/v2/cli/internal/state"
"github.com/edgelesssys/constellation/v2/internal/api/attestationconfigapi"
"github.com/edgelesssys/constellation/v2/internal/atls"
@ -45,6 +44,7 @@ import (
"github.com/google/go-sev-guest/kds"
"github.com/spf13/afero"
"github.com/spf13/cobra"
"github.com/spf13/pflag"
"google.golang.org/grpc"
)
@ -64,8 +64,39 @@ func NewVerifyCmd() *cobra.Command {
return cmd
}
type verifyFlags struct {
rootFlags
endpoint string
ownerID string
clusterID string
output string
}
func (f *verifyFlags) parse(flags *pflag.FlagSet) error {
if err := f.rootFlags.parse(flags); err != nil {
return err
}
var err error
f.output, err = flags.GetString("output")
if err != nil {
return fmt.Errorf("getting 'output' flag: %w", err)
}
f.endpoint, err = flags.GetString("node-endpoint")
if err != nil {
return fmt.Errorf("getting 'node-endpoint' flag: %w", err)
}
f.clusterID, err = flags.GetString("cluster-id")
if err != nil {
return fmt.Errorf("getting 'cluster-id' flag: %w", err)
}
return nil
}
type verifyCmd struct {
log debugLog
fileHandler file.Handler
flags verifyFlags
log debugLog
}
func runVerify(cmd *cobra.Command, _ []string) error {
@ -95,22 +126,23 @@ func runVerify(cmd *cobra.Command, _ []string) error {
return nil, fmt.Errorf("invalid output value for formatter: %s", output)
}
}
v := &verifyCmd{log: log}
v := &verifyCmd{
fileHandler: fileHandler,
log: log,
}
if err := v.flags.parse(cmd.Flags()); err != nil {
return err
}
v.log.Debugf("Using flags: %+v", v.flags)
fetcher := attestationconfigapi.NewFetcher()
return v.verify(cmd, fileHandler, verifyClient, formatterFactory, fetcher)
return v.verify(cmd, verifyClient, formatterFactory, fetcher)
}
type formatterFactory func(output string, provider cloudprovider.Provider, log debugLog) (attestationDocFormatter, error)
func (c *verifyCmd) verify(cmd *cobra.Command, fileHandler file.Handler, verifyClient verifyClient, factory formatterFactory, configFetcher attestationconfigapi.Fetcher) error {
flags, err := c.parseVerifyFlags(cmd, fileHandler)
if err != nil {
return fmt.Errorf("parsing flags: %w", err)
}
c.log.Debugf("Using flags: %+v", flags)
c.log.Debugf("Loading configuration file from %q", flags.pf.PrefixPrintablePath(constants.ConfigFilename))
conf, err := config.New(fileHandler, constants.ConfigFilename, configFetcher, flags.force)
func (c *verifyCmd) verify(cmd *cobra.Command, verifyClient verifyClient, factory formatterFactory, configFetcher attestationconfigapi.Fetcher) error {
c.log.Debugf("Loading configuration file from %q", c.flags.pathPrefixer.PrefixPrintablePath(constants.ConfigFilename))
conf, err := config.New(c.fileHandler, constants.ConfigFilename, configFetcher, c.flags.force)
var configValidationErr *config.ValidationError
if errors.As(err, &configValidationErr) {
cmd.PrintErrln(configValidationErr.LongMessage())
@ -119,10 +151,29 @@ func (c *verifyCmd) verify(cmd *cobra.Command, fileHandler file.Handler, verifyC
return fmt.Errorf("loading config file: %w", err)
}
conf.UpdateMAAURL(flags.maaURL)
stateFile, err := state.ReadFromFile(c.fileHandler, constants.StateFilename)
if err != nil {
return fmt.Errorf("reading state file: %w", err)
}
ownerID, clusterID, err := c.validateIDFlags(cmd, stateFile)
if err != nil {
return err
}
endpoint, err := c.validateEndpointFlag(cmd, stateFile)
if err != nil {
return err
}
var maaURL string
if stateFile.Infrastructure.Azure != nil {
maaURL = stateFile.Infrastructure.Azure.AttestationURL
}
conf.UpdateMAAURL(maaURL)
c.log.Debugf("Updating expected PCRs")
attConfig := conf.GetAttestationConfig()
if err := cloudcmd.UpdateInitMeasurements(attConfig, flags.ownerID, flags.clusterID); err != nil {
if err := cloudcmd.UpdateInitMeasurements(attConfig, ownerID, clusterID); err != nil {
return fmt.Errorf("updating expected PCRs: %w", err)
}
@ -140,7 +191,7 @@ func (c *verifyCmd) verify(cmd *cobra.Command, fileHandler file.Handler, verifyC
rawAttestationDoc, err := verifyClient.Verify(
cmd.Context(),
flags.endpoint,
endpoint,
&verifyproto.GetAttestationRequest{
Nonce: nonce,
},
@ -151,7 +202,7 @@ func (c *verifyCmd) verify(cmd *cobra.Command, fileHandler file.Handler, verifyC
}
// certificates are only available for Azure
formatter, err := factory(flags.output, conf.GetProvider(), c.log)
formatter, err := factory(c.flags.output, conf.GetProvider(), c.log)
if err != nil {
return fmt.Errorf("creating formatter: %w", err)
}
@ -160,7 +211,7 @@ func (c *verifyCmd) verify(cmd *cobra.Command, fileHandler file.Handler, verifyC
rawAttestationDoc,
conf.Provider.Azure == nil,
attConfig.GetMeasurements(),
flags.maaURL,
maaURL,
)
if err != nil {
return fmt.Errorf("printing attestation document: %w", err)
@ -171,114 +222,37 @@ func (c *verifyCmd) verify(cmd *cobra.Command, fileHandler file.Handler, verifyC
return nil
}
func (c *verifyCmd) parseVerifyFlags(cmd *cobra.Command, fileHandler file.Handler) (verifyFlags, error) {
workDir, err := cmd.Flags().GetString("workspace")
if err != nil {
return verifyFlags{}, fmt.Errorf("parsing config path argument: %w", err)
func (c *verifyCmd) validateIDFlags(cmd *cobra.Command, stateFile *state.State) (ownerID, clusterID string, err error) {
ownerID, clusterID = c.flags.ownerID, c.flags.clusterID
if c.flags.clusterID == "" {
cmd.PrintErrf("Using ID from %q. Specify --cluster-id to override this.\n", c.flags.pathPrefixer.PrefixPrintablePath(constants.StateFilename))
clusterID = stateFile.ClusterValues.ClusterID
}
c.log.Debugf("Flag 'workspace' set to %q", workDir)
pf := pathprefix.New(workDir)
ownerID := ""
clusterID, err := cmd.Flags().GetString("cluster-id")
if err != nil {
return verifyFlags{}, fmt.Errorf("parsing cluster-id argument: %w", err)
}
c.log.Debugf("Flag 'cluster-id' set to %q", clusterID)
endpoint, err := cmd.Flags().GetString("node-endpoint")
if err != nil {
return verifyFlags{}, fmt.Errorf("parsing node-endpoint argument: %w", err)
}
c.log.Debugf("Flag 'node-endpoint' set to %q", endpoint)
force, err := cmd.Flags().GetBool("force")
if err != nil {
return verifyFlags{}, fmt.Errorf("parsing force argument: %w", err)
}
c.log.Debugf("Flag 'force' set to %t", force)
output, err := cmd.Flags().GetString("output")
if err != nil {
return verifyFlags{}, fmt.Errorf("parsing raw argument: %w", err)
}
c.log.Debugf("Flag 'output' set to %t", output)
// Get empty values from state file
stateFile, err := state.ReadFromFile(fileHandler, constants.StateFilename)
isFileNotFound := errors.Is(err, afero.ErrFileNotFound)
if isFileNotFound {
c.log.Debugf("State file %q not found, using empty state", pf.PrefixPrintablePath(constants.StateFilename))
stateFile = state.New() // error compat
} else if err != nil {
return verifyFlags{}, fmt.Errorf("reading state file: %w", err)
}
emptyEndpoint := endpoint == ""
emptyIDs := ownerID == "" && clusterID == ""
if emptyEndpoint || emptyIDs {
c.log.Debugf("Trying to supplement empty flag values from %q", pf.PrefixPrintablePath(constants.StateFilename))
if emptyEndpoint {
cmd.PrintErrf("Using endpoint from %q. Specify --node-endpoint to override this.\n", pf.PrefixPrintablePath(constants.StateFilename))
endpoint = stateFile.Infrastructure.ClusterEndpoint
}
if emptyIDs {
cmd.PrintErrf("Using ID from %q. Specify --cluster-id to override this.\n", pf.PrefixPrintablePath(constants.StateFilename))
ownerID = stateFile.ClusterValues.OwnerID
clusterID = stateFile.ClusterValues.ClusterID
}
}
var attestationURL string
if stateFile.Infrastructure.Azure != nil {
attestationURL = stateFile.Infrastructure.Azure.AttestationURL
if ownerID == "" {
// We don't want to print warnings until this is implemented again
// cmd.PrintErrf("Using ID from %q. Specify --owner-id to override this.\n", c.flags.pathPrefixer.PrefixPrintablePath(constants.StateFilename))
ownerID = stateFile.ClusterValues.OwnerID
}
// Validate
if ownerID == "" && clusterID == "" {
return verifyFlags{}, errors.New("cluster-id not provided to verify the cluster")
}
endpoint, err = addPortIfMissing(endpoint, constants.VerifyServiceNodePortGRPC)
if err != nil {
return verifyFlags{}, fmt.Errorf("validating endpoint argument: %w", err)
return "", "", errors.New("cluster-id not provided to verify the cluster")
}
return verifyFlags{
endpoint: endpoint,
pf: pf,
ownerID: ownerID,
clusterID: clusterID,
output: output,
maaURL: attestationURL,
force: force,
}, nil
return ownerID, clusterID, nil
}
type verifyFlags struct {
endpoint string
ownerID string
clusterID string
maaURL string
output string
force bool
pf pathprefix.PathPrefixer
}
func addPortIfMissing(endpoint string, defaultPort int) (string, error) {
func (c *verifyCmd) validateEndpointFlag(cmd *cobra.Command, stateFile *state.State) (string, error) {
endpoint := c.flags.endpoint
if endpoint == "" {
return "", errors.New("endpoint is empty")
cmd.PrintErrf("Using endpoint from %q. Specify --node-endpoint to override this.\n", c.flags.pathPrefixer.PrefixPrintablePath(constants.StateFilename))
endpoint = stateFile.Infrastructure.ClusterEndpoint
}
_, _, err := net.SplitHostPort(endpoint)
if err == nil {
return endpoint, nil
endpoint, err := addPortIfMissing(endpoint, constants.VerifyServiceNodePortGRPC)
if err != nil {
return "", fmt.Errorf("validating endpoint argument: %w", err)
}
if strings.Contains(err.Error(), "missing port in address") {
return net.JoinHostPort(endpoint, strconv.Itoa(defaultPort)), nil
}
return "", err
return endpoint, nil
}
// an attestationDocFormatter formats the attestation document.
@ -869,3 +843,20 @@ func extractAzureInstanceInfo(docString string) (azureInstanceInfo, error) {
}
return instanceInfo, nil
}
func addPortIfMissing(endpoint string, defaultPort int) (string, error) {
if endpoint == "" {
return "", errors.New("endpoint is empty")
}
_, _, err := net.SplitHostPort(endpoint)
if err == nil {
return endpoint, nil
}
if strings.Contains(err.Error(), "missing port in address") {
return net.JoinHostPort(endpoint, strconv.Itoa(defaultPort)), nil
}
return "", err
}

View File

@ -58,6 +58,7 @@ func TestVerify(t *testing.T) {
nodeEndpointFlag: "192.0.2.1:1234",
clusterIDFlag: zeroBase64,
protoClient: &stubVerifyClient{},
stateFile: state.New(),
wantEndpoint: "192.0.2.1:1234",
formatter: &stubAttDocFormatter{},
},
@ -66,6 +67,7 @@ func TestVerify(t *testing.T) {
nodeEndpointFlag: "192.0.2.1:1234",
clusterIDFlag: zeroBase64,
protoClient: &stubVerifyClient{},
stateFile: state.New(),
wantEndpoint: "192.0.2.1:1234",
formatter: &stubAttDocFormatter{},
},
@ -74,6 +76,7 @@ func TestVerify(t *testing.T) {
nodeEndpointFlag: "192.0.2.1",
clusterIDFlag: zeroBase64,
protoClient: &stubVerifyClient{},
stateFile: state.New(),
wantEndpoint: "192.0.2.1:" + strconv.Itoa(constants.VerifyServiceNodePortGRPC),
formatter: &stubAttDocFormatter{},
},
@ -81,6 +84,7 @@ func TestVerify(t *testing.T) {
provider: cloudprovider.GCP,
clusterIDFlag: zeroBase64,
protoClient: &stubVerifyClient{},
stateFile: state.New(),
formatter: &stubAttDocFormatter{},
wantErr: true,
},
@ -106,12 +110,14 @@ func TestVerify(t *testing.T) {
nodeEndpointFlag: ":::::",
clusterIDFlag: zeroBase64,
protoClient: &stubVerifyClient{},
stateFile: state.New(),
formatter: &stubAttDocFormatter{},
wantErr: true,
},
"neither owner id nor cluster id set": {
provider: cloudprovider.GCP,
nodeEndpointFlag: "192.0.2.1:1234",
stateFile: state.New(),
formatter: &stubAttDocFormatter{},
wantErr: true,
},
@ -127,6 +133,7 @@ func TestVerify(t *testing.T) {
provider: cloudprovider.GCP,
clusterIDFlag: zeroBase64,
nodeEndpointFlag: "192.0.2.1:1234",
stateFile: state.New(),
formatter: &stubAttDocFormatter{},
skipConfigCreation: true,
wantErr: true,
@ -136,6 +143,7 @@ func TestVerify(t *testing.T) {
nodeEndpointFlag: "192.0.2.1:1234",
clusterIDFlag: zeroBase64,
protoClient: &stubVerifyClient{verifyErr: rpcStatus.Error(codes.Internal, "failed")},
stateFile: state.New(),
formatter: &stubAttDocFormatter{},
wantErr: true,
},
@ -144,6 +152,7 @@ func TestVerify(t *testing.T) {
nodeEndpointFlag: "192.0.2.1:1234",
clusterIDFlag: zeroBase64,
protoClient: &stubVerifyClient{verifyErr: someErr},
stateFile: state.New(),
formatter: &stubAttDocFormatter{},
wantErr: true,
},
@ -152,6 +161,7 @@ func TestVerify(t *testing.T) {
nodeEndpointFlag: "192.0.2.1:1234",
clusterIDFlag: zeroBase64,
protoClient: &stubVerifyClient{},
stateFile: state.New(),
wantEndpoint: "192.0.2.1:1234",
formatter: &stubAttDocFormatter{formatErr: someErr},
wantErr: true,
@ -164,31 +174,28 @@ func TestVerify(t *testing.T) {
require := require.New(t)
cmd := NewVerifyCmd()
cmd.Flags().String("workspace", "", "") // register persistent flag manually
cmd.Flags().Bool("force", true, "") // register persistent flag manually
out := &bytes.Buffer{}
cmd.SetErr(out)
if tc.clusterIDFlag != "" {
require.NoError(cmd.Flags().Set("cluster-id", tc.clusterIDFlag))
}
if tc.nodeEndpointFlag != "" {
require.NoError(cmd.Flags().Set("node-endpoint", tc.nodeEndpointFlag))
}
fileHandler := file.NewHandler(afero.NewMemMapFs())
if !tc.skipConfigCreation {
cfg := defaultConfigWithExpectedMeasurements(t, config.Default(), tc.provider)
require.NoError(fileHandler.WriteYAML(constants.ConfigFilename, cfg))
}
if tc.stateFile != nil {
require.NoError(tc.stateFile.WriteToFile(fileHandler, constants.StateFilename))
}
require.NoError(tc.stateFile.WriteToFile(fileHandler, constants.StateFilename))
v := &verifyCmd{log: logger.NewTest(t)}
v := &verifyCmd{
fileHandler: fileHandler,
log: logger.NewTest(t),
flags: verifyFlags{
clusterID: tc.clusterIDFlag,
endpoint: tc.nodeEndpointFlag,
},
}
formatterFac := func(_ string, _ cloudprovider.Provider, _ debugLog) (attestationDocFormatter, error) {
return tc.formatter, nil
}
err := v.verify(cmd, fileHandler, tc.protoClient, formatterFac, stubAttestationFetcher{})
err := v.verify(cmd, tc.protoClient, formatterFac, stubAttestationFetcher{})
if tc.wantErr {
assert.Error(err)
} else {

5
go.mod
View File

@ -109,6 +109,7 @@ require (
github.com/sigstore/sigstore v1.7.1
github.com/spf13/afero v1.10.0
github.com/spf13/cobra v1.7.0
github.com/spf13/pflag v1.0.5
github.com/stretchr/testify v1.8.4
github.com/theupdateframework/go-tuf v0.5.2
github.com/tink-crypto/tink-go/v2 v2.0.0
@ -137,8 +138,6 @@ require (
sigs.k8s.io/yaml v1.3.0
)
require github.com/google/go-tdx-guest v0.2.2 // indirect
require (
cloud.google.com/go v0.110.2 // indirect
cloud.google.com/go/iam v1.1.0 // indirect
@ -234,6 +233,7 @@ require (
github.com/google/go-attestation v0.5.0 // indirect
github.com/google/go-cmp v0.5.9 // indirect
github.com/google/go-containerregistry v0.15.2 // indirect
github.com/google/go-tdx-guest v0.2.2 // indirect
github.com/google/go-tspi v0.3.0 // indirect
github.com/google/gofuzz v1.2.0 // indirect
github.com/google/logger v1.1.1 // indirect
@ -306,7 +306,6 @@ require (
github.com/shopspring/decimal v1.3.1 // indirect
github.com/sirupsen/logrus v1.9.0 // indirect
github.com/spf13/cast v1.5.1 // indirect
github.com/spf13/pflag v1.0.5 // indirect
github.com/stretchr/objx v0.5.0 // indirect
github.com/titanous/rocacheck v0.0.0-20171023193734-afe73141d399 // indirect
github.com/transparency-dev/merkle v0.0.2 // indirect