constellation/internal/config/config_test.go
2024-01-16 17:34:44 +01:00

1110 lines
32 KiB
Go

/*
Copyright (c) Edgeless Systems GmbH
SPDX-License-Identifier: AGPL-3.0-only
*/
package config
import (
"context"
"errors"
"reflect"
"testing"
"github.com/go-playground/locales/en"
ut "github.com/go-playground/universal-translator"
"github.com/go-playground/validator/v10"
"github.com/spf13/afero"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/goleak"
"gopkg.in/yaml.v3"
"github.com/edgelesssys/constellation/v2/internal/api/attestationconfigapi"
"github.com/edgelesssys/constellation/v2/internal/attestation/measurements"
"github.com/edgelesssys/constellation/v2/internal/attestation/variant"
"github.com/edgelesssys/constellation/v2/internal/cloud/cloudprovider"
"github.com/edgelesssys/constellation/v2/internal/config/instancetypes"
"github.com/edgelesssys/constellation/v2/internal/constants"
"github.com/edgelesssys/constellation/v2/internal/file"
"github.com/edgelesssys/constellation/v2/internal/semver"
"github.com/edgelesssys/constellation/v2/internal/versions"
gosemver "golang.org/x/mod/semver"
)
func TestMain(m *testing.M) {
goleak.VerifyTestMain(m)
}
func TestDefaultConfig(t *testing.T) {
assert := assert.New(t)
def := Default()
assert.NotNil(def)
}
func TestDefaultConfigMarshalsLatestVersion(t *testing.T) {
conf := Default()
bt, err := yaml.Marshal(conf)
require := require.New(t)
require.NoError(err)
var mp configMap
require.NoError(yaml.Unmarshal(bt, &mp))
assert := assert.New(t)
assert.Equal("latest", mp.getSEVSNPVersion("microcodeVersion"))
assert.Equal("latest", mp.getSEVSNPVersion("teeVersion"))
assert.Equal("latest", mp.getSEVSNPVersion("snpVersion"))
assert.Equal("latest", mp.getSEVSNPVersion("bootloaderVersion"))
}
func TestGetAttestationConfigMarshalsNumericalVersion(t *testing.T) {
conf := Default()
conf.RemoveProviderAndAttestationExcept(cloudprovider.Azure)
attestationCfg := conf.GetAttestationConfig()
bt, err := yaml.Marshal(attestationCfg)
require := require.New(t)
require.NoError(err)
var mp map[string]interface{}
require.NoError(yaml.Unmarshal(bt, &mp))
assert := assert.New(t)
assert.Equal(placeholderVersionValue, mp["microcodeVersion"])
assert.Equal(placeholderVersionValue, mp["teeVersion"])
assert.Equal(placeholderVersionValue, mp["snpVersion"])
assert.Equal(placeholderVersionValue, mp["bootloaderVersion"])
}
func TestNew(t *testing.T) {
testCases := map[string]struct {
config configMap
configName string
wantResult *Config
wantErr bool
wantedErrType error
}{
"Azure SEV-SNP: mix of Latest and uint as version value in file correctly sets latest versions values": {
config: func() configMap {
conf := Default() // default configures latest version
modifyConfigForAzureToPassValidate(conf)
m := getConfigAsMap(conf, t)
m.setSEVSNPVersion("microcodeVersion", "Latest") // check uppercase also works
m.setSEVSNPVersion("teeVersion", 2)
m.setSEVSNPVersion("bootloaderVersion", 1)
return m
}(),
configName: constants.ConfigFilename,
wantResult: func() *Config {
conf := Default()
modifyConfigForAzureToPassValidate(conf)
conf.Attestation.AzureSEVSNP.MicrocodeVersion = AttestationVersion{
Value: testCfg.Microcode,
WantLatest: true,
}
conf.Attestation.AzureSEVSNP.TEEVersion = AttestationVersion{
Value: 2,
WantLatest: false,
}
conf.Attestation.AzureSEVSNP.BootloaderVersion = AttestationVersion{
Value: 1,
WantLatest: false,
}
conf.Attestation.AzureSEVSNP.SNPVersion = AttestationVersion{
Value: testCfg.SNP,
WantLatest: true,
}
return conf
}(),
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
fileHandler := file.NewHandler(afero.NewMemMapFs())
if tc.config != nil {
require.NoError(fileHandler.WriteYAML(tc.configName, tc.config, file.OptNone))
}
result, err := New(fileHandler, tc.configName, stubAttestationFetcher{}, false)
if tc.wantErr {
assert.Error(err)
return
}
assert.NoError(err)
assert.Equal(tc.wantResult, result)
})
}
}
func modifyConfigForAzureToPassValidate(c *Config) {
c.RemoveProviderAndAttestationExcept(cloudprovider.Azure)
c.Image = constants.BinaryVersion().String()
c.Provider.Azure.SubscriptionID = "11111111-1111-1111-1111-111111111111"
c.Provider.Azure.TenantID = "11111111-1111-1111-1111-111111111111"
c.Provider.Azure.Location = "westus"
c.Provider.Azure.ResourceGroup = "test"
c.Provider.Azure.UserAssignedIdentity = "/subscriptions/11111111-1111-1111-1111-111111111111/resourcegroups/constellation-identity/providers/Microsoft.ManagedIdentity/userAssignedIdentities/constellation-identity"
c.Attestation.AzureSEVSNP.Measurements = measurements.M{
0: measurements.WithAllBytes(0x00, measurements.Enforce, measurements.PCRMeasurementLength),
}
c.NodeGroups = map[string]NodeGroup{
constants.ControlPlaneDefault: {
Role: "control-plane",
Zone: "",
InstanceType: "Standard_DC4as_v5",
StateDiskSizeGB: 30,
StateDiskType: "StandardSSD_LRS",
InitialCount: 3,
},
constants.WorkerDefault: {
Role: "worker",
Zone: "",
InstanceType: "Standard_DC4as_v5",
StateDiskSizeGB: 30,
StateDiskType: "StandardSSD_LRS",
InitialCount: 3,
},
}
}
func TestReadConfigFile(t *testing.T) {
testCases := map[string]struct {
config configMap
configName string
wantResult *Config
wantErr bool
wantedErrType error
}{
"refuse invalid version value": {
config: func() configMap {
conf := Default()
m := getConfigAsMap(conf, t)
m.setSEVSNPVersion("microcodeVersion", "1a")
return m
}(),
configName: constants.ConfigFilename,
wantErr: true,
},
"outdated k8s patch version is allowed": {
config: func() configMap {
conf := Default()
ver, err := semver.New(versions.SupportedK8sVersions()[0])
require.NoError(t, err)
conf.KubernetesVersion = versions.ValidK8sVersion(semver.NewFromInt(ver.Major(), ver.Minor(), ver.Patch()-1, "").String())
m := getConfigAsMap(conf, t)
return m
}(),
wantResult: func() *Config {
conf := Default()
ver, err := semver.New(versions.SupportedK8sVersions()[0])
require.NoError(t, err)
conf.KubernetesVersion = versions.ValidK8sVersion(semver.NewFromInt(ver.Major(), ver.Minor(), ver.Patch()-1, "").String())
return conf
}(),
configName: constants.ConfigFilename,
},
"outdated k8s version is not allowed": {
config: func() configMap {
conf := Default()
conf.KubernetesVersion = versions.ValidK8sVersion("v1.0.0")
m := getConfigAsMap(conf, t)
return m
}(),
wantErr: true,
configName: constants.ConfigFilename,
},
"a k8s version without specified patch is not allowed": {
config: func() configMap {
conf := Default()
conf.KubernetesVersion = versions.ValidK8sVersion(gosemver.MajorMinor(string(versions.Default)))
m := getConfigAsMap(conf, t)
return m
}(),
wantErr: true,
configName: constants.ConfigFilename,
},
"error on entering app client id": {
config: func() configMap {
conf := Default()
m := getConfigAsMap(conf, t)
m.setAzureProvider("appClientID", "3ea4bdc1-1cc1-4237-ae78-0831eff3491e")
return m
}(),
configName: constants.ConfigFilename,
wantedErrType: &UnsupportedAppRegistrationError{},
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
fileHandler := file.NewHandler(afero.NewMemMapFs())
if tc.config != nil {
require.NoError(fileHandler.WriteYAML(tc.configName, tc.config, file.OptNone))
}
result, err := fromFile(fileHandler, tc.configName)
if tc.wantedErrType != nil {
assert.ErrorIs(err, tc.wantedErrType)
return
}
if tc.wantErr {
assert.Error(err)
return
}
assert.NoError(err)
assert.Equal(tc.wantResult, result)
})
}
}
func TestFromFile(t *testing.T) {
testCases := map[string]struct {
config *Config
configName string
wantResult *Config
wantErr bool
}{
"default config from default file": {
config: Default(),
configName: constants.ConfigFilename,
wantResult: Default(),
},
"default config from different path": {
config: Default(),
configName: "other-config.yaml",
wantResult: Default(),
},
"default config when path empty": {
config: nil,
configName: "",
wantErr: true,
},
"err when path not exist": {
config: nil,
configName: "wrong-name.yaml",
wantErr: true,
},
"modify default config": {
config: func() *Config {
conf := Default()
conf.Provider.GCP.Region = "eu-north1"
conf.Provider.GCP.Zone = "eu-north1-a"
return conf
}(),
configName: constants.ConfigFilename,
wantResult: func() *Config {
conf := Default()
conf.Provider.GCP.Region = "eu-north1"
conf.Provider.GCP.Zone = "eu-north1-a"
return conf
}(),
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
fileHandler := file.NewHandler(afero.NewMemMapFs())
if tc.config != nil {
require.NoError(fileHandler.WriteYAML(tc.configName, tc.config, file.OptNone))
}
result, err := fromFile(fileHandler, tc.configName)
if tc.wantErr {
assert.Error(err)
} else {
require.NoError(err)
assert.Equal(tc.wantResult, result)
}
})
}
}
func TestValidate(t *testing.T) {
const defaultErrCount = 37 // expect this number of error messages by default because user-specific values are not set and multiple providers are defined by default
const azErrCount = 7
const awsErrCount = 8
const gcpErrCount = 8
// TODO(AB#3132,3u13r): refactor config validation tests
// Note that the `cnf.Image = ""` is a hack to align `bazel test` with `go test` behavior
// since first does version stamping.
testCases := map[string]struct {
cnf *Config
wantErr bool
wantErrCount int
}{
"default config is not valid": {
cnf: func() *Config {
cnf := Default()
cnf.Image = ""
return cnf
}(),
wantErr: true,
wantErrCount: defaultErrCount,
},
"microservices violate version drift": {
cnf: func() *Config {
cnf := Default()
cnf.Image = ""
cliVersion := constants.BinaryVersion()
cnf.MicroserviceVersion = semver.NewFromInt(cliVersion.Major()+2, cliVersion.Minor(), cliVersion.Patch(), "")
return cnf
}(),
wantErr: true,
// This is a very different value from the other error counts because of the way we are checking MicroserviceVersions.
wantErrCount: 1,
},
"v0 is one error": {
cnf: func() *Config {
cnf := Default()
cnf.Image = ""
cnf.Version = "v0"
return cnf
}(),
wantErr: true,
wantErrCount: defaultErrCount + 1,
},
"v0 and long name are two errors": {
cnf: func() *Config {
cnf := Default()
cnf.Image = ""
cnf.Version = "v0"
cnf.Name = "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"
return cnf
}(),
wantErr: true,
wantErrCount: defaultErrCount + 2,
},
"default Azure config is not valid": {
cnf: func() *Config {
cnf := Default()
cnf.Image = ""
cnf.RemoveProviderAndAttestationExcept(cloudprovider.Azure)
return cnf
}(),
wantErr: true,
wantErrCount: azErrCount,
},
"Azure config with all required fields is valid": {
cnf: func() *Config {
cnf := Default()
cnf.RemoveProviderAndAttestationExcept(cloudprovider.Azure)
cnf.Image = constants.BinaryVersion().String()
modifyConfigForAzureToPassValidate(cnf)
return cnf
}(),
},
"default AWS config is not valid": {
cnf: func() *Config {
cnf := Default()
cnf.RemoveProviderAndAttestationExcept(cloudprovider.AWS)
return cnf
}(),
wantErr: true,
wantErrCount: awsErrCount,
},
"AWS config with upper case name": {
cnf: func() *Config {
cnf := Default()
cnf.RemoveProviderAndAttestationExcept(cloudprovider.AWS)
cnf.Name = "testAWS"
return cnf
}(),
wantErr: true,
wantErrCount: awsErrCount + 1,
},
"AWS config with correct region and zone format": {
cnf: func() *Config {
cnf := Default()
cnf.Provider.AWS.Region = "us-east-2"
cnf.Provider.AWS.Zone = "us-east-2a"
cnf.RemoveProviderAndAttestationExcept(cloudprovider.AWS)
return cnf
}(),
wantErr: true,
wantErrCount: awsErrCount - 4,
},
"AWS config with wrong region and zone format": {
cnf: func() *Config {
cnf := Default()
cnf.Provider.AWS.Region = "us-west2"
cnf.Provider.AWS.Zone = "a"
cnf.RemoveProviderAndAttestationExcept(cloudprovider.AWS)
return cnf
}(),
wantErr: true,
wantErrCount: awsErrCount,
},
"default GCP config is not valid": {
cnf: func() *Config {
cnf := Default()
cnf.RemoveProviderAndAttestationExcept(cloudprovider.GCP)
cnf.Image = ""
return cnf
}(),
wantErr: true,
wantErrCount: gcpErrCount,
},
"GCP config with all required fields is valid": {
cnf: func() *Config {
cnf := Default()
cnf.RemoveProviderAndAttestationExcept(cloudprovider.GCP)
cnf.Image = constants.BinaryVersion().String()
gcp := cnf.Provider.GCP
gcp.Region = "test-region"
gcp.Project = "test-project"
gcp.Zone = "test-zone"
gcp.ServiceAccountKeyPath = "test-key-path"
cnf.Provider = ProviderConfig{}
cnf.Provider.GCP = gcp
cnf.Attestation.GCPSEVES.Measurements = measurements.M{
0: measurements.WithAllBytes(0x00, measurements.Enforce, measurements.PCRMeasurementLength),
}
cnf.NodeGroups = map[string]NodeGroup{
constants.ControlPlaneDefault: {
Role: "control-plane",
Zone: "europe-west1-b",
InstanceType: "n2d-standard-4",
StateDiskSizeGB: 30,
StateDiskType: "pd-ssd",
InitialCount: 3,
},
constants.WorkerDefault: {
Role: "worker",
Zone: "europe-west1-b",
InstanceType: "n2d-standard-4",
StateDiskSizeGB: 30,
StateDiskType: "pd-ssd",
InitialCount: 3,
},
}
return cnf
}(),
},
"miniup default config is not valid because image and measurements are missing in OSS": {
cnf: func() *Config {
config, _ := MiniDefault()
require.NotNil(t, config)
return config
}(),
wantErr: true,
wantErrCount: 2,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
err := tc.cnf.Validate(false)
if tc.wantErr {
assert.Error(err)
var valErr *ValidationError
require.ErrorAs(err, &valErr)
assert.Equalf(tc.wantErrCount, valErr.messagesCount(), "Got unexpected error count: %d: %s", valErr.messagesCount(), valErr.LongMessage())
return
}
assert.NoError(err)
})
}
}
func TestHasProvider(t *testing.T) {
assert := assert.New(t)
assert.False((&Config{}).HasProvider(cloudprovider.Unknown))
assert.False((&Config{}).HasProvider(cloudprovider.Azure))
assert.False((&Config{}).HasProvider(cloudprovider.GCP))
assert.False((&Config{}).HasProvider(cloudprovider.QEMU))
assert.False(Default().HasProvider(cloudprovider.Unknown))
assert.True(Default().HasProvider(cloudprovider.Azure))
assert.True(Default().HasProvider(cloudprovider.GCP))
cnfWithAzure := Config{Provider: ProviderConfig{Azure: &AzureConfig{}}}
assert.False(cnfWithAzure.HasProvider(cloudprovider.Unknown))
assert.True(cnfWithAzure.HasProvider(cloudprovider.Azure))
assert.False(cnfWithAzure.HasProvider(cloudprovider.GCP))
}
func TestConfigRemoveProviderExcept(t *testing.T) {
testCases := map[string]struct {
removeExcept cloudprovider.Provider
wantAWS *AWSConfig
wantAzure *AzureConfig
wantGCP *GCPConfig
wantQEMU *QEMUConfig
}{
"except aws": {
removeExcept: cloudprovider.AWS,
wantAWS: Default().Provider.AWS,
},
"except azure": {
removeExcept: cloudprovider.Azure,
wantAzure: Default().Provider.Azure,
},
"except gcp": {
removeExcept: cloudprovider.GCP,
wantGCP: Default().Provider.GCP,
},
"except qemu": {
removeExcept: cloudprovider.QEMU,
wantQEMU: Default().Provider.QEMU,
},
"unknown provider": {
removeExcept: cloudprovider.Unknown,
wantAWS: Default().Provider.AWS,
wantAzure: Default().Provider.Azure,
wantGCP: Default().Provider.GCP,
wantQEMU: Default().Provider.QEMU,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
conf := Default()
conf.RemoveProviderAndAttestationExcept(tc.removeExcept)
assert.Equal(tc.wantAWS, conf.Provider.AWS)
assert.Equal(tc.wantAzure, conf.Provider.Azure)
assert.Equal(tc.wantGCP, conf.Provider.GCP)
assert.Equal(tc.wantQEMU, conf.Provider.QEMU)
})
}
}
func TestConfigGeneratedDocsFresh(t *testing.T) {
assert := assert.New(t)
updateMsg := "remember to re-generate config docs! 🔨"
assert.Len(ConfigDoc.Fields, reflect.ValueOf(Config{}).NumField(), updateMsg)
assert.Len(ProviderConfigDoc.Fields, reflect.ValueOf(ProviderConfig{}).NumField(), updateMsg)
assert.Len(AWSConfigDoc.Fields, reflect.ValueOf(AWSConfig{}).NumField(), updateMsg)
assert.Len(AzureConfigDoc.Fields, reflect.ValueOf(AzureConfig{}).NumField(), updateMsg)
assert.Len(GCPConfigDoc.Fields, reflect.ValueOf(GCPConfig{}).NumField(), updateMsg)
assert.Len(QEMUConfigDoc.Fields, reflect.ValueOf(QEMUConfig{}).NumField(), updateMsg)
}
func TestConfig_UpdateMeasurements(t *testing.T) {
assert := assert.New(t)
newMeasurements := measurements.M{
1: measurements.WithAllBytes(0x00, measurements.Enforce, measurements.PCRMeasurementLength),
2: measurements.WithAllBytes(0x01, measurements.Enforce, measurements.PCRMeasurementLength),
3: measurements.WithAllBytes(0x02, measurements.Enforce, measurements.PCRMeasurementLength),
}
{ // AWS
conf := Default()
conf.RemoveProviderAndAttestationExcept(cloudprovider.AWS)
for k := range conf.Attestation.AWSSEVSNP.Measurements {
delete(conf.Attestation.AWSSEVSNP.Measurements, k)
}
conf.UpdateMeasurements(newMeasurements)
assert.Equal(newMeasurements, conf.Attestation.AWSSEVSNP.Measurements)
}
{ // Azure
conf := Default()
conf.RemoveProviderAndAttestationExcept(cloudprovider.Azure)
for k := range conf.Attestation.AzureSEVSNP.Measurements {
delete(conf.Attestation.AzureSEVSNP.Measurements, k)
}
conf.UpdateMeasurements(newMeasurements)
assert.Equal(newMeasurements, conf.Attestation.AzureSEVSNP.Measurements)
}
{ // GCP
conf := Default()
conf.RemoveProviderAndAttestationExcept(cloudprovider.GCP)
for k := range conf.Attestation.GCPSEVES.Measurements {
delete(conf.Attestation.GCPSEVES.Measurements, k)
}
conf.UpdateMeasurements(newMeasurements)
assert.Equal(newMeasurements, conf.Attestation.GCPSEVES.Measurements)
}
{ // QEMU
conf := Default()
conf.RemoveProviderAndAttestationExcept(cloudprovider.QEMU)
for k := range conf.Attestation.QEMUVTPM.Measurements {
delete(conf.Attestation.QEMUVTPM.Measurements, k)
}
conf.UpdateMeasurements(newMeasurements)
assert.Equal(newMeasurements, conf.Attestation.QEMUVTPM.Measurements)
}
}
func TestConfig_IsReleaseImage(t *testing.T) {
testCases := map[string]struct {
conf *Config
want bool
}{
"release image v0.0.0": {
conf: func() *Config {
conf := Default()
conf.Image = "v0.0.0"
return conf
}(),
want: true,
},
"branch image": {
conf: func() *Config {
conf := Default()
conf.Image = "feat-x-vX.Y.Z-pre.0.yyyymmddhhmmss-abcdefabcdef"
return conf
}(),
want: false,
},
"debug image": {
conf: func() *Config {
conf := Default()
conf.Image = "debug-vX.Y.Z-pre.0.yyyymmddhhmmss-abcdefabcdef"
return conf
}(),
want: false,
},
"empty config": {
conf: &Config{},
want: false,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
assert.Equal(tc.want, tc.conf.IsReleaseImage())
})
}
}
func TestValidInstanceTypeForProvider(t *testing.T) {
testCases := map[string]struct {
provider cloudprovider.Provider
instanceTypes []string
nonCVMsAllowed bool
expectedResult bool
}{
"empty all": {
provider: cloudprovider.Unknown,
instanceTypes: []string{},
expectedResult: false,
},
"empty aws": {
provider: cloudprovider.AWS,
instanceTypes: []string{},
expectedResult: false,
},
"empty azure only CVMs": {
provider: cloudprovider.Azure,
instanceTypes: []string{},
expectedResult: false,
},
"empty azure with non-CVMs": {
provider: cloudprovider.Azure,
instanceTypes: []string{},
nonCVMsAllowed: true,
expectedResult: false,
},
"empty gcp": {
provider: cloudprovider.GCP,
instanceTypes: []string{},
expectedResult: false,
},
"azure only CVMs (SNP)": {
provider: cloudprovider.Azure,
instanceTypes: instancetypes.AzureSNPInstanceTypes,
expectedResult: true,
},
"azure only CVMs (TDX)": {
provider: cloudprovider.Azure,
instanceTypes: instancetypes.AzureTDXInstanceTypes,
expectedResult: true,
},
"azure CVMs but CVMs disabled": {
provider: cloudprovider.Azure,
instanceTypes: instancetypes.AzureSNPInstanceTypes,
nonCVMsAllowed: true,
expectedResult: false,
},
"azure trusted launch VMs with CVMs enabled": {
provider: cloudprovider.Azure,
instanceTypes: instancetypes.AzureTrustedLaunchInstanceTypes,
expectedResult: false,
},
"azure trusted launch VMs with CVMs disabled": {
provider: cloudprovider.Azure,
instanceTypes: instancetypes.AzureTrustedLaunchInstanceTypes,
nonCVMsAllowed: true,
expectedResult: true,
},
"gcp": {
provider: cloudprovider.GCP,
instanceTypes: instancetypes.GCPInstanceTypes,
expectedResult: true,
},
"put gcp when azure is set": {
provider: cloudprovider.Azure,
instanceTypes: instancetypes.GCPInstanceTypes,
expectedResult: false,
},
"put gcp when azure is set with CVMs disabled": {
provider: cloudprovider.Azure,
instanceTypes: instancetypes.GCPInstanceTypes,
nonCVMsAllowed: true,
expectedResult: false,
},
"put azure when gcp is set": {
provider: cloudprovider.GCP,
instanceTypes: instancetypes.AzureSNPInstanceTypes,
expectedResult: false,
},
"put azure when gcp is set with CVMs disabled": {
provider: cloudprovider.GCP,
instanceTypes: instancetypes.AzureTrustedLaunchInstanceTypes,
nonCVMsAllowed: true,
expectedResult: false,
},
// Testing every possible instance type for AWS is not feasible, so we just test a few based on known supported / unsupported families
// Also serves as a test for checkIfInstanceInValidAWSFamilys
"aws two valid instances": {
provider: cloudprovider.AWS,
instanceTypes: []string{"c5.xlarge", "c5a.2xlarge", "c5a.16xlarge", "u-12tb1.112xlarge"},
expectedResult: true,
nonCVMsAllowed: true,
},
"aws one valid instance one with too little vCPUs": {
provider: cloudprovider.AWS,
instanceTypes: []string{"c5.medium"},
expectedResult: false,
nonCVMsAllowed: true,
},
"aws graviton sub-family unsupported": {
provider: cloudprovider.AWS,
instanceTypes: []string{"m6g.xlarge", "r6g.2xlarge", "x2gd.xlarge", "g5g.8xlarge"},
expectedResult: false,
nonCVMsAllowed: true,
},
"aws combined two valid instances as one string": {
provider: cloudprovider.AWS,
instanceTypes: []string{"c5.xlarge, c5a.2xlarge"},
expectedResult: false,
nonCVMsAllowed: true,
},
"aws only CVMs": {
provider: cloudprovider.AWS,
instanceTypes: []string{"c6a.xlarge", "m6a.xlarge", "r6a.xlarge"},
expectedResult: true,
},
"aws CVMs but CVMs disabled": {
provider: cloudprovider.AWS,
instanceTypes: []string{"m6a.xlarge", "c6a.xlarge", "r6a.xlarge"},
nonCVMsAllowed: true,
expectedResult: true,
},
"aws nitroTPM VMs with CVMs enabled": {
provider: cloudprovider.AWS,
instanceTypes: []string{"c5.xlarge", "c5a.2xlarge", "c5a.16xlarge", "u-12tb1.112xlarge"},
expectedResult: false,
},
"aws nitroTPM VMs with CVMs disabled": {
provider: cloudprovider.AWS,
instanceTypes: []string{"c5.xlarge", "c5a.2xlarge", "c5a.16xlarge", "u-12tb1.112xlarge"},
nonCVMsAllowed: true,
expectedResult: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
for _, instanceType := range tc.instanceTypes {
assert.Equal(tc.expectedResult, validInstanceTypeForProvider(instanceType, tc.nonCVMsAllowed, tc.provider), instanceType)
}
})
}
}
func TestIsDebugCluster(t *testing.T) {
testCases := map[string]struct {
config *Config
prepareConfig func(*Config)
expectedResult bool
}{
"empty config": {
config: &Config{},
expectedResult: false,
},
"default config": {
config: Default(),
expectedResult: false,
},
"enabled": {
config: Default(),
prepareConfig: func(conf *Config) {
*conf.DebugCluster = true
},
expectedResult: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
if tc.prepareConfig != nil {
tc.prepareConfig(tc.config)
}
assert.Equal(tc.expectedResult, tc.config.IsDebugCluster())
})
}
}
func TestValidateProvider(t *testing.T) {
testCases := map[string]struct {
provider ProviderConfig
wantErr bool
expectedErrorTag string
}{
"empty, should trigger no provider error": {
provider: ProviderConfig{},
wantErr: true,
expectedErrorTag: "no_provider",
},
"azure only, should be okay": {
provider: ProviderConfig{
Azure: &AzureConfig{},
},
wantErr: false,
},
"gcp only, should be okay": {
provider: ProviderConfig{
GCP: &GCPConfig{},
},
wantErr: false,
},
"qemu only, should be okay": {
provider: ProviderConfig{
QEMU: &QEMUConfig{},
},
wantErr: false,
},
"azure and gcp, should trigger multiple provider error": {
provider: ProviderConfig{
Azure: &AzureConfig{},
GCP: &GCPConfig{},
},
wantErr: true,
expectedErrorTag: "more_than_one_provider",
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
v := validator.New()
trans := ut.New(en.New()).GetFallback()
conf := Default()
conf.Provider = tc.provider
v.RegisterStructValidation(validateProvider, ProviderConfig{})
err := v.StructPartial(tc.provider)
// Register provider validation error types.
// Make sure the tags and expected strings below are in sync with the actual implementation.
require.NoError(v.RegisterTranslation("no_provider", trans, registerNoProviderError, translateNoProviderError))
require.NoError(v.RegisterTranslation("more_than_one_provider", trans, registerMoreThanOneProviderError, conf.translateMoreThanOneProviderError))
// Continue if no error is expected.
if !tc.wantErr {
assert.NoError(err)
return
}
// Validate if the error was identified correctly.
require.NotNil(err)
assert.Error(err)
assert.Contains(err.Error(), tc.expectedErrorTag)
// Check if error translation works correctly.
validationErr := err.(validator.ValidationErrors)
translatedErr := validationErr.Translate(trans)
// The translator does not seem to export a list of available translations or for a specific field.
// So we need to hardcode expected strings. Needs to be in sync with implementation.
switch tc.expectedErrorTag {
case "no_provider":
assert.Contains(translatedErr["ProviderConfig.Provider"], "No provider has been defined")
case "more_than_one_provider":
assert.Contains(translatedErr["ProviderConfig.Provider"], "Only one provider can be defined")
}
})
}
}
func TestConfigVersionCompatibility(t *testing.T) {
t.Skip() // TODO(daniel-weisse): re-enable and re-write for config v4
testCases := map[string]struct {
config string
expectedConfig *Config
}{
"config v2 gcp": {
config: "testdata/configGCPV2.yaml",
expectedConfig: &Config{
Version: "v2",
Image: "v2.5.0",
KubernetesVersion: "1.23",
DebugCluster: toPtr(false),
Provider: ProviderConfig{
GCP: &GCPConfig{
Project: "project-12345",
Region: "europe-west3",
Zone: "europe-west3-b",
ServiceAccountKeyPath: "serviceAccountKey.json",
DeployCSIDriver: toPtr(true),
},
},
},
},
"config v2 aws": {
config: "testdata/configAWSV2.yaml",
expectedConfig: &Config{
Version: "v2",
Image: "v2.5.0",
KubernetesVersion: "1.23",
DebugCluster: toPtr(false),
Provider: ProviderConfig{
AWS: &AWSConfig{
Region: "us-east-2",
Zone: "us-east-2a",
IAMProfileControlPlane: "control_plane_instance_profile_name",
IAMProfileWorkerNodes: "node_instance_profile",
},
},
},
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
fileHandler := file.NewHandler(afero.NewOsFs())
config, err := fromFile(fileHandler, tc.config)
assert.NoError(err)
assert.Equal(tc.expectedConfig, config)
})
}
}
func TestIsDebugImage(t *testing.T) {
cases := map[string]struct {
image string
expected bool
}{
"debug image": {"ref/test/stream/debug/v2.9.0-pre.0.20230613084544-eeea7b1f56f4", true},
"release image": {"v2.8.0", false},
"empty image": {"", false},
}
for name, tc := range cases {
t.Run(name, func(t *testing.T) {
c := &Config{Image: tc.image}
assert.Equal(t, tc.expected, c.IsNamedLikeDebugImage())
})
}
}
func TestIsAppClientIDError(t *testing.T) {
testCases := map[string]struct {
err error
expected bool
}{
"yaml.Error with appClientID error": {
err: &yaml.TypeError{
Errors: []string{
"invalid value for appClientID",
"another error",
},
},
expected: true,
},
"yaml.Error without appClientID error": {
err: &yaml.TypeError{
Errors: []string{
"invalid value for something else",
"another error",
},
},
expected: false,
},
"other error": {
err: errors.New("appClientID but other error type"),
expected: false,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
assert.Equal(tc.expected, isAppClientIDError(tc.err))
})
}
}
// configMap is used to un-/marshal the config as an unstructured map.
type configMap map[string]interface{}
func (c configMap) setSEVSNPVersion(versionType string, value interface{}) {
c["attestation"].(configMap)["azureSEVSNP"].(configMap)[versionType] = value
}
func (c configMap) setAzureProvider(azureProviderField string, value interface{}) {
c["provider"].(configMap)["azure"].(configMap)[azureProviderField] = value
}
func (c configMap) getSEVSNPVersion(versionType string) interface{} {
return c["attestation"].(configMap)["azureSEVSNP"].(configMap)[versionType]
}
// getConfigAsMap returns a map of the config.
func getConfigAsMap(conf *Config, t *testing.T) (res configMap) {
bytes, err := yaml.Marshal(&conf)
if err != nil {
t.Fatal(err)
}
if err := yaml.Unmarshal(bytes, &res); err != nil {
t.Fatal(err)
}
return
}
type stubAttestationFetcher struct{}
func (f stubAttestationFetcher) FetchSEVSNPVersionList(_ context.Context, _ attestationconfigapi.SEVSNPVersionList) (attestationconfigapi.SEVSNPVersionList, error) {
return attestationconfigapi.SEVSNPVersionList{}, nil
}
func (f stubAttestationFetcher) FetchSEVSNPVersion(_ context.Context, _ attestationconfigapi.SEVSNPVersionAPI) (attestationconfigapi.SEVSNPVersionAPI, error) {
return attestationconfigapi.SEVSNPVersionAPI{
SEVSNPVersion: testCfg,
}, nil
}
func (f stubAttestationFetcher) FetchSEVSNPVersionLatest(_ context.Context, _ variant.Variant) (attestationconfigapi.SEVSNPVersionAPI, error) {
return attestationconfigapi.SEVSNPVersionAPI{
SEVSNPVersion: testCfg,
}, nil
}
var testCfg = attestationconfigapi.SEVSNPVersion{
Microcode: 93,
TEE: 0,
SNP: 6,
Bootloader: 2,
}