/* Copyright (c) Edgeless Systems GmbH SPDX-License-Identifier: AGPL-3.0-only */ package cmd import ( "fmt" "strings" "testing" "github.com/edgelesssys/constellation/v2/internal/attestation/variant" "github.com/edgelesssys/constellation/v2/internal/cloud/cloudprovider" "github.com/edgelesssys/constellation/v2/internal/config" "github.com/edgelesssys/constellation/v2/internal/constants" "github.com/edgelesssys/constellation/v2/internal/constellation/state" "github.com/edgelesssys/constellation/v2/internal/file" "github.com/edgelesssys/constellation/v2/internal/logger" "github.com/edgelesssys/constellation/v2/internal/versions" "github.com/spf13/afero" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "golang.org/x/mod/semver" ) func TestParseKubernetesVersion(t *testing.T) { testCases := map[string]struct { version string wantErr bool }{ "default version": { version: "", }, "without v prefix": { version: strings.TrimPrefix(string(versions.Default), "v"), }, "K8s version without patch version": { version: semver.MajorMinor(string(versions.Default)), }, "K8s version with patch version": { version: string(versions.Default), }, "K8s version with invalid patch version": { version: func() string { s := string(versions.Default) return s[:len(s)-1] + "99" }(), wantErr: true, }, "outdated K8s version": { version: "v1.0.0", wantErr: true, }, "no semver": { version: "asdf", wantErr: true, }, "not supported": { version: "1111", wantErr: true, }, } for name, tc := range testCases { t.Run(name, func(t *testing.T) { assert := assert.New(t) require := require.New(t) flags := newConfigGenerateCmd().Flags() if tc.version != "" { require.NoError(flags.Set("kubernetes", tc.version)) } version, err := parseK8sFlag(flags) if tc.wantErr { assert.Error(err) return } assert.NoError(err) assert.Equal(versions.Default, version) }) } } func TestConfigGenerateDefault(t *testing.T) { assert := assert.New(t) require := require.New(t) fileHandler := file.NewHandler(afero.NewMemMapFs()) cmd := newConfigGenerateCmd() cg := &configGenerateCmd{ log: logger.NewTest(t), flags: generateFlags{ attestationVariant: variant.Dummy{}, k8sVersion: versions.Default, }, } require.NoError(cg.configGenerate(cmd, fileHandler, cloudprovider.Unknown, "")) var readConfig config.Config err := fileHandler.ReadYAML(constants.ConfigFilename, &readConfig) assert.NoError(err) assert.Equal(*config.Default(), readConfig) _, err = state.ReadFromFile(fileHandler, constants.StateFilename) assert.NoError(err) } func TestConfigGenerateDefaultProviderSpecific(t *testing.T) { testCases := map[string]struct { provider cloudprovider.Provider rawProvider string }{ "aws": { provider: cloudprovider.AWS, }, "azure": { provider: cloudprovider.Azure, }, "gcp": { provider: cloudprovider.GCP, }, "openstack": { provider: cloudprovider.OpenStack, }, "stackit": { provider: cloudprovider.OpenStack, rawProvider: "stackit", }, } for name, tc := range testCases { t.Run(name, func(t *testing.T) { assert := assert.New(t) require := require.New(t) fileHandler := file.NewHandler(afero.NewMemMapFs()) cmd := newConfigGenerateCmd() wantConf := config.Default().WithOpenStackProviderDefaults(tc.rawProvider) wantConf.RemoveProviderAndAttestationExcept(tc.provider) cg := &configGenerateCmd{ log: logger.NewTest(t), flags: generateFlags{ attestationVariant: variant.Dummy{}, k8sVersion: versions.Default, }, } require.NoError(cg.configGenerate(cmd, fileHandler, tc.provider, tc.rawProvider)) var readConfig config.Config err := fileHandler.ReadYAML(constants.ConfigFilename, &readConfig) assert.NoError(err) assert.Equal(*wantConf, readConfig) stateFile, err := state.ReadFromFile(fileHandler, constants.StateFilename) assert.NoError(err) switch tc.provider { case cloudprovider.GCP: assert.NotNil(stateFile.Infrastructure.GCP) case cloudprovider.Azure: assert.NotNil(stateFile.Infrastructure.Azure) } }) } } func TestConfigGenerateDefaultExists(t *testing.T) { require := require.New(t) fileHandler := file.NewHandler(afero.NewMemMapFs()) require.NoError(fileHandler.Write(constants.ConfigFilename, []byte("foobar"), file.OptNone)) cmd := newConfigGenerateCmd() cg := &configGenerateCmd{ log: logger.NewTest(t), flags: generateFlags{attestationVariant: variant.Dummy{}}, } require.Error(cg.configGenerate(cmd, fileHandler, cloudprovider.Unknown, "")) } func TestNoValidProviderAttestationCombination(t *testing.T) { assert := assert.New(t) tests := []struct { provider cloudprovider.Provider attestation variant.Variant }{ {cloudprovider.Azure, variant.AWSNitroTPM{}}, {cloudprovider.AWS, variant.AzureTrustedLaunch{}}, {cloudprovider.GCP, variant.AWSNitroTPM{}}, {cloudprovider.QEMU, variant.GCPSEVES{}}, {cloudprovider.OpenStack, variant.AWSNitroTPM{}}, } for _, test := range tests { t.Run("", func(t *testing.T) { _, err := createConfigWithAttestationVariant(test.provider, "", test.attestation) assert.Error(err) }) } } func TestValidProviderAttestationCombination(t *testing.T) { defaultAttestation := config.Default().Attestation tests := []struct { provider cloudprovider.Provider attestation variant.Variant expected config.AttestationConfig }{ { cloudprovider.Azure, variant.AzureTrustedLaunch{}, config.AttestationConfig{AzureTrustedLaunch: defaultAttestation.AzureTrustedLaunch}, }, { cloudprovider.Azure, variant.AzureSEVSNP{}, config.AttestationConfig{AzureSEVSNP: defaultAttestation.AzureSEVSNP}, }, { cloudprovider.AWS, variant.AWSSEVSNP{}, config.AttestationConfig{AWSSEVSNP: defaultAttestation.AWSSEVSNP}, }, { cloudprovider.AWS, variant.AWSNitroTPM{}, config.AttestationConfig{AWSNitroTPM: defaultAttestation.AWSNitroTPM}, }, { cloudprovider.GCP, variant.GCPSEVES{}, config.AttestationConfig{GCPSEVES: defaultAttestation.GCPSEVES}, }, { cloudprovider.QEMU, variant.QEMUVTPM{}, config.AttestationConfig{QEMUVTPM: defaultAttestation.QEMUVTPM}, }, { cloudprovider.OpenStack, variant.QEMUVTPM{}, config.AttestationConfig{QEMUVTPM: defaultAttestation.QEMUVTPM}, }, } for _, test := range tests { t.Run(fmt.Sprintf("Provider:%s,Attestation:%s", test.provider, test.attestation), func(t *testing.T) { sut, err := createConfigWithAttestationVariant(test.provider, "", test.attestation) assert := assert.New(t) assert.NoError(err) assert.Equal(test.expected, sut.Attestation) }) } } func TestParseAttestationFlag(t *testing.T) { testCases := map[string]struct { wantErr bool attestationFlag string wantVariant variant.Variant }{ "invalid": { wantErr: true, attestationFlag: "unknown", }, "AzureTrustedLaunch": { attestationFlag: "azure-trustedlaunch", wantVariant: variant.AzureTrustedLaunch{}, }, "AzureSEVSNP": { attestationFlag: "azure-sev-snp", wantVariant: variant.AzureSEVSNP{}, }, "AWSSEVSNP": { attestationFlag: "aws-sev-snp", wantVariant: variant.AWSSEVSNP{}, }, "AWSNitroTPM": { attestationFlag: "aws-nitro-tpm", wantVariant: variant.AWSNitroTPM{}, }, "GCPSEVES": { attestationFlag: "gcp-sev-es", wantVariant: variant.GCPSEVES{}, }, "QEMUVTPM": { attestationFlag: "qemu-vtpm", wantVariant: variant.QEMUVTPM{}, }, "no flag": { wantVariant: variant.Dummy{}, }, } for name, tc := range testCases { t.Run(name, func(t *testing.T) { require := require.New(t) assert := assert.New(t) cmd := newConfigGenerateCmd() if tc.attestationFlag != "" { require.NoError(cmd.Flags().Set("attestation", tc.attestationFlag)) } attestation, err := parseAttestationFlag(cmd.Flags()) if tc.wantErr { assert.Error(err) return } require.NoError(err) assert.True(tc.wantVariant.Equal(attestation)) }) } }