constellation/internal/config/config_test.go

595 lines
16 KiB
Go

/*
Copyright (c) Edgeless Systems GmbH
SPDX-License-Identifier: AGPL-3.0-only
*/
package config
import (
"reflect"
"testing"
"github.com/edgelesssys/constellation/internal/cloud/cloudprovider"
"github.com/edgelesssys/constellation/internal/config/instancetypes"
"github.com/edgelesssys/constellation/internal/constants"
"github.com/edgelesssys/constellation/internal/file"
"github.com/go-playground/locales/en"
ut "github.com/go-playground/universal-translator"
"github.com/go-playground/validator/v10"
"github.com/spf13/afero"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/goleak"
)
// defaultMsgCount is the number of validation messages expected for the
// default config: user-specific values are not set and multiple providers
// are defined by default.
const defaultMsgCount = 13
// TestMain runs the package's tests through goleak.VerifyTestMain, which
// reports goroutines that are still running once the tests have finished.
func TestMain(m *testing.M) {
goleak.VerifyTestMain(m)
}
// TestDefaultConfig checks that Default returns a non-nil config.
func TestDefaultConfig(t *testing.T) {
	assert.NotNil(t, Default())
}
// TestFromFile exercises FromFile against an in-memory filesystem. Each case
// optionally writes a config as YAML under configName, loads that path again
// and expects either an error or a config equal to wantResult.
func TestFromFile(t *testing.T) {
	// customConfig returns a small hand-written config.
	customConfig := func() *Config {
		return &Config{
			Version:                 Version1,
			AutoscalingNodeGroupMin: 42,
			AutoscalingNodeGroupMax: 1337,
		}
	}
	// modifiedDefault returns the default config with a changed GCP region and zone.
	modifiedDefault := func() *Config {
		cfg := Default()
		cfg.Provider.GCP.Region = "eu-north1"
		cfg.Provider.GCP.Zone = "eu-north1-a"
		return cfg
	}

	tests := map[string]struct {
		config     *Config // written to configName before loading; nil means nothing is written
		configName string
		wantResult *Config
		wantErr    bool
	}{
		"default config from default file": {
			config:     Default(),
			configName: constants.ConfigFilename,
			wantResult: Default(),
		},
		"default config from different path": {
			config:     Default(),
			configName: "other-config.yaml",
			wantResult: Default(),
		},
		"default config when path empty": {
			configName: "",
			wantErr:    true,
		},
		"err when path not exist": {
			configName: "wrong-name.yaml",
			wantErr:    true,
		},
		"custom config from default file": {
			config:     customConfig(),
			configName: constants.ConfigFilename,
			wantResult: customConfig(),
		},
		"modify default config": {
			config:     modifiedDefault(),
			configName: constants.ConfigFilename,
			wantResult: modifiedDefault(),
		},
	}

	for name, tc := range tests {
		t.Run(name, func(t *testing.T) {
			handler := file.NewHandler(afero.NewMemMapFs())
			if tc.config != nil {
				require.NoError(t, handler.WriteYAML(tc.configName, tc.config, file.OptNone))
			}

			got, err := FromFile(handler, tc.configName)
			if tc.wantErr {
				assert.Error(t, err)
				return
			}
			require.NoError(t, err)
			assert.Equal(t, tc.wantResult, got)
		})
	}
}
// TestFromFileStrictErrors writes raw YAML to the default config path of an
// in-memory filesystem and checks whether FromFile accepts or rejects it,
// according to each case's wantErr flag.
//
// Previously wantErr was never consulted and every case, including
// "valid config", asserted that FromFile returns an error.
//
// NOTE(review): the "valid config" case spells the key `stateDisksizeGB`,
// while the Go field is StateDiskSizeGB; confirm the key matches Config's YAML
// tag, otherwise strict parsing rejects it.
// NOTE(review): the "unsupported version" case assumes FromFile itself
// rejects unknown versions; confirm against FromFile.
func TestFromFileStrictErrors(t *testing.T) {
	testCases := map[string]struct {
		yamlConfig string
		wantErr    bool
	}{
		"valid config": {
			yamlConfig: `
autoscalingNodeGroupMin: 5
autoscalingNodeGroupMax: 10
stateDisksizeGB: 25
`,
		},
		"typo": {
			yamlConfig: `
autoscalingNodeGroupMini: 5
autoscalingNodeGroupMax: 10
stateDisksizeGB: 25
`,
			wantErr: true,
		},
		"unsupported version": {
			yamlConfig: `
version: v5
autoscalingNodeGroupMin: 1
autoscalingNodeGroupMax: 10
stateDisksizeGB: 30
`,
			wantErr: true,
		},
	}

	for name, tc := range testCases {
		t.Run(name, func(t *testing.T) {
			assert := assert.New(t)
			require := require.New(t)

			// Writing the fixture is a precondition: stop here if it fails so a
			// missing file cannot be mistaken for a parsing error below.
			fileHandler := file.NewHandler(afero.NewMemMapFs())
			require.NoError(fileHandler.Write(constants.ConfigFilename, []byte(tc.yamlConfig), file.OptNone))

			_, err := FromFile(fileHandler, constants.ConfigFilename)
			if tc.wantErr {
				assert.Error(err)
				return
			}
			assert.NoError(err)
		})
	}
}
// TestValidate runs Validate on the default config and on variants of it with
// invalid fields, and checks the number of returned validation messages.
func TestValidate(t *testing.T) {
	// defaultWith returns the default config after applying mutate to it.
	defaultWith := func(mutate func(*Config)) *Config {
		c := Default()
		mutate(c)
		return c
	}

	tests := map[string]struct {
		cnf          *Config
		wantMsgCount int
	}{
		"default config is valid": {
			cnf:          Default(),
			wantMsgCount: defaultMsgCount,
		},
		"config with 1 error": {
			cnf:          defaultWith(func(c *Config) { c.Version = "v0" }),
			wantMsgCount: defaultMsgCount + 1,
		},
		"config with 2 errors": {
			cnf: defaultWith(func(c *Config) {
				c.Version = "v0"
				c.StateDiskSizeGB = -1
			}),
			wantMsgCount: defaultMsgCount + 2,
		},
	}

	for name, tc := range tests {
		t.Run(name, func(t *testing.T) {
			messages, err := tc.cnf.Validate()
			require.NoError(t, err)
			assert.Len(t, messages, tc.wantMsgCount)
		})
	}
}
func TestHasProvider(t *testing.T) {
assert := assert.New(t)
assert.False((&Config{}).HasProvider(cloudprovider.Unknown))
assert.False((&Config{}).HasProvider(cloudprovider.Azure))
assert.False((&Config{}).HasProvider(cloudprovider.GCP))
assert.False((&Config{}).HasProvider(cloudprovider.QEMU))
assert.False(Default().HasProvider(cloudprovider.Unknown))
assert.True(Default().HasProvider(cloudprovider.Azure))
assert.True(Default().HasProvider(cloudprovider.GCP))
cnfWithAzure := Config{Provider: ProviderConfig{Azure: &AzureConfig{}}}
assert.False(cnfWithAzure.HasProvider(cloudprovider.Unknown))
assert.True(cnfWithAzure.HasProvider(cloudprovider.Azure))
assert.False(cnfWithAzure.HasProvider(cloudprovider.GCP))
}
func TestImage(t *testing.T) {
testCases := map[string]struct {
cfg *Config
wantImage string
}{
"default azure": {
cfg: func() *Config { c := Default(); c.RemoveProviderExcept(cloudprovider.Azure); return c }(),
wantImage: Default().Provider.Azure.Image,
},
"default gcp": {
cfg: func() *Config { c := Default(); c.RemoveProviderExcept(cloudprovider.GCP); return c }(),
wantImage: Default().Provider.GCP.Image,
},
"default qemu": {
cfg: func() *Config { c := Default(); c.RemoveProviderExcept(cloudprovider.QEMU); return c }(),
wantImage: "",
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
image := tc.cfg.Image()
assert.Equal(tc.wantImage, image)
})
}
}
// TestConfigRemoveProviderExcept applies RemoveProviderExcept to a default
// config and checks which provider sections remain. A provider section that
// is not expected is compared against a nil pointer; for the unknown provider
// all three default sections are expected to stay.
func TestConfigRemoveProviderExcept(t *testing.T) {
	tests := map[string]struct {
		removeExcept cloudprovider.Provider
		wantAzure    *AzureConfig
		wantGCP      *GCPConfig
		wantQEMU     *QEMUConfig
	}{
		"except azure": {
			removeExcept: cloudprovider.Azure,
			wantAzure:    Default().Provider.Azure,
		},
		"except gcp": {
			removeExcept: cloudprovider.GCP,
			wantGCP:      Default().Provider.GCP,
		},
		"except qemu": {
			removeExcept: cloudprovider.QEMU,
			wantQEMU:     Default().Provider.QEMU,
		},
		"unknown provider": {
			removeExcept: cloudprovider.Unknown,
			wantAzure:    Default().Provider.Azure,
			wantGCP:      Default().Provider.GCP,
			wantQEMU:     Default().Provider.QEMU,
		},
	}

	for name, tc := range tests {
		t.Run(name, func(t *testing.T) {
			got := Default()
			got.RemoveProviderExcept(tc.removeExcept)

			providers := got.Provider
			assert.Equal(t, tc.wantAzure, providers.Azure)
			assert.Equal(t, tc.wantGCP, providers.GCP)
			assert.Equal(t, tc.wantQEMU, providers.QEMU)
		})
	}
}
func TestConfigGeneratedDocsFresh(t *testing.T) {
assert := assert.New(t)
updateMsg := "remember to re-generate config docs! 🔨"
assert.Len(ConfigDoc.Fields, reflect.ValueOf(Config{}).NumField(), updateMsg)
assert.Len(UpgradeConfigDoc.Fields, reflect.ValueOf(UpgradeConfig{}).NumField(), updateMsg)
assert.Len(UserKeyDoc.Fields, reflect.ValueOf(UserKey{}).NumField(), updateMsg)
assert.Len(ProviderConfigDoc.Fields, reflect.ValueOf(ProviderConfig{}).NumField(), updateMsg)
assert.Len(AzureConfigDoc.Fields, reflect.ValueOf(AzureConfig{}).NumField(), updateMsg)
assert.Len(GCPConfigDoc.Fields, reflect.ValueOf(GCPConfig{}).NumField(), updateMsg)
assert.Len(QEMUConfigDoc.Fields, reflect.ValueOf(QEMUConfig{}).NumField(), updateMsg)
}
func TestConfig_UpdateMeasurements(t *testing.T) {
assert := assert.New(t)
newMeasurements := Measurements{
1: []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
2: []byte{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1},
3: []byte{2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2},
}
{ // Azure
conf := Default()
conf.RemoveProviderExcept(cloudprovider.Azure)
for k := range conf.Provider.Azure.Measurements {
delete(conf.Provider.Azure.Measurements, k)
}
conf.UpdateMeasurements(newMeasurements)
assert.Equal(newMeasurements, conf.Provider.Azure.Measurements)
}
{ // GCP
conf := Default()
conf.RemoveProviderExcept(cloudprovider.GCP)
for k := range conf.Provider.GCP.Measurements {
delete(conf.Provider.GCP.Measurements, k)
}
conf.UpdateMeasurements(newMeasurements)
assert.Equal(newMeasurements, conf.Provider.GCP.Measurements)
}
{ // QEMU
conf := Default()
conf.RemoveProviderExcept(cloudprovider.QEMU)
for k := range conf.Provider.QEMU.Measurements {
delete(conf.Provider.QEMU.Measurements, k)
}
conf.UpdateMeasurements(newMeasurements)
assert.Equal(newMeasurements, conf.Provider.QEMU.Measurements)
}
}
// TestConfig_IsImageDebug checks IsDebugImage for GCP and Azure configs with
// release-style and debug-style image references, and for an empty config.
func TestConfig_IsImageDebug(t *testing.T) {
	// gcpWithImage returns a GCP-only default config using the given image.
	gcpWithImage := func(image string) *Config {
		c := Default()
		c.RemoveProviderExcept(cloudprovider.GCP)
		c.Provider.GCP.Image = image
		return c
	}
	// azureWithImage returns an Azure-only default config using the given image.
	azureWithImage := func(image string) *Config {
		c := Default()
		c.RemoveProviderExcept(cloudprovider.Azure)
		c.Provider.Azure.Image = image
		return c
	}

	tests := map[string]struct {
		conf *Config
		want bool
	}{
		"gcp release": {
			conf: gcpWithImage("projects/constellation-images/global/images/constellation-v1-3-0"),
			want: false,
		},
		"gcp debug": {
			conf: gcpWithImage("projects/constellation-images/global/images/constellation-20220812102023"),
			want: true,
		},
		"azure release": {
			conf: azureWithImage("/CommunityGalleries/ConstellationCVM-b3782fa0-0df7-4f2f-963e-fc7fc42663df/Images/constellation/Versions/0.0.1"),
			want: false,
		},
		"azure debug": {
			conf: azureWithImage("/subscriptions/0d202bbb-4fa7-4af8-8125-58c269a05435/resourceGroups/constellation-images/providers/Microsoft.Compute/galleries/Constellation_Debug/images/v1.4.0/versions/2022.0805.151600"),
			want: true,
		},
		"empty config": {
			conf: &Config{},
			want: false,
		},
	}

	for name, tc := range tests {
		t.Run(name, func(t *testing.T) {
			got := tc.conf.IsDebugImage()
			assert.Equal(t, tc.want, got)
		})
	}
}
// TestValidInstanceTypeForProvider calls validInstanceTypeForProvider for
// every instance type of a case and expects the same result for each of them.
//
// The "empty ..." cases have no instance types, so their loop body never runs
// and they make no assertion.
func TestValidInstanceTypeForProvider(t *testing.T) {
	const (
		azure   = cloudprovider.Azure
		gcp     = cloudprovider.GCP
		unknown = cloudprovider.Unknown
	)

	tests := map[string]struct {
		provider       cloudprovider.Provider
		instanceTypes  []string
		nonCVMsAllowed bool
		expectedResult bool
	}{
		"empty all": {
			provider: unknown,
		},
		"empty azure only CVMs": {
			provider: azure,
		},
		"empty azure with non-CVMs": {
			provider:       azure,
			nonCVMsAllowed: true,
		},
		"empty gcp": {
			provider: gcp,
		},
		"azure only CVMs": {
			provider:       azure,
			instanceTypes:  instancetypes.AzureCVMInstanceTypes,
			expectedResult: true,
		},
		"azure CVMs but CVMs disabled": {
			provider:       azure,
			instanceTypes:  instancetypes.AzureCVMInstanceTypes,
			nonCVMsAllowed: true,
		},
		"azure trusted launch VMs with CVMs enabled": {
			provider:      azure,
			instanceTypes: instancetypes.AzureTrustedLaunchInstanceTypes,
		},
		"azure trusted launch VMs with CVMs disbled": {
			provider:       azure,
			instanceTypes:  instancetypes.AzureTrustedLaunchInstanceTypes,
			nonCVMsAllowed: true,
			expectedResult: true,
		},
		"gcp": {
			provider:       gcp,
			instanceTypes:  instancetypes.GCPInstanceTypes,
			expectedResult: true,
		},
		"put gcp when azure is set": {
			provider:      azure,
			instanceTypes: instancetypes.GCPInstanceTypes,
		},
		"put gcp when azure is set with CVMs disabled": {
			provider:       azure,
			instanceTypes:  instancetypes.GCPInstanceTypes,
			nonCVMsAllowed: true,
		},
		"put azure when gcp is set": {
			provider:      gcp,
			instanceTypes: instancetypes.AzureCVMInstanceTypes,
		},
		"put azure when gcp is set with CVMs disabled": {
			provider:       gcp,
			instanceTypes:  instancetypes.AzureTrustedLaunchInstanceTypes,
			nonCVMsAllowed: true,
		},
	}

	for name, tc := range tests {
		t.Run(name, func(t *testing.T) {
			for i := range tc.instanceTypes {
				instanceType := tc.instanceTypes[i]
				got := validInstanceTypeForProvider(instanceType, tc.nonCVMsAllowed, tc.provider)
				// The instance type is passed as the failure message to identify the offender.
				assert.Equal(t, tc.expectedResult, got, instanceType)
			}
		})
	}
}
// TestIsDebugCluster checks IsDebugCluster for an empty config, the default
// config and a default config whose DebugCluster flag has been set to true.
func TestIsDebugCluster(t *testing.T) {
	tests := map[string]struct {
		config         *Config
		prepareConfig  func(*Config) // optional; applied to config before the check
		expectedResult bool
	}{
		"empty config": {
			config: &Config{},
		},
		"default config": {
			config: Default(),
		},
		"enabled": {
			config:         Default(),
			prepareConfig:  func(c *Config) { *c.DebugCluster = true },
			expectedResult: true,
		},
	}

	for name, tc := range tests {
		t.Run(name, func(t *testing.T) {
			if prepare := tc.prepareConfig; prepare != nil {
				prepare(tc.config)
			}
			got := tc.config.IsDebugCluster()
			assert.Equal(t, tc.expectedResult, got)
		})
	}
}
// TestValidateProvider registers validateProvider as struct-level validation
// for ProviderConfig and checks that configs with no provider or with more
// than one provider are rejected with the expected error tag, while configs
// with exactly one provider pass. For rejected configs it also checks the
// translated error message.
func TestValidateProvider(t *testing.T) {
	testCases := map[string]struct {
		provider         ProviderConfig
		wantErr          bool
		expectedErrorTag string
	}{
		"empty, should trigger no provider error": {
			provider:         ProviderConfig{},
			wantErr:          true,
			expectedErrorTag: "no_provider",
		},
		"azure only, should be okay": {
			provider: ProviderConfig{
				Azure: &AzureConfig{},
			},
			wantErr: false,
		},
		"gcp only, should be okay": {
			provider: ProviderConfig{
				GCP: &GCPConfig{},
			},
			wantErr: false,
		},
		"qemu only, should be okay": {
			provider: ProviderConfig{
				QEMU: &QEMUConfig{},
			},
			wantErr: false,
		},
		"azure and gcp, should trigger multiple provider error": {
			provider: ProviderConfig{
				Azure: &AzureConfig{},
				GCP:   &GCPConfig{},
			},
			wantErr:          true,
			expectedErrorTag: "more_than_one_provider",
		},
	}

	for name, tc := range testCases {
		t.Run(name, func(t *testing.T) {
			assert := assert.New(t)
			require := require.New(t)

			v := validator.New()
			trans := ut.New(en.New()).GetFallback()

			// The "more than one provider" translation is a method on Config,
			// so it needs a config that carries the providers under test.
			conf := Default()
			conf.Provider = tc.provider

			v.RegisterStructValidation(validateProvider, ProviderConfig{})
			err := v.StructPartial(tc.provider)

			// Register provider validation error types.
			// Make sure the tags and expected strings below are in sync with the actual implementation.
			require.NoError(v.RegisterTranslation("no_provider", trans, registerNoProviderError, translateNoProviderError))
			require.NoError(v.RegisterTranslation("more_than_one_provider", trans, registerMoreThanOneProviderError, conf.translateMoreThanOneProviderError))

			// Nothing more to check if no error is expected.
			if !tc.wantErr {
				assert.NoError(err)
				return
			}

			// Validate if the error was identified correctly.
			require.Error(err)
			assert.Contains(err.Error(), tc.expectedErrorTag)

			// Check if error translation works correctly. The type assertion is
			// checked so an unexpected error type fails the test instead of panicking.
			validationErr, ok := err.(validator.ValidationErrors)
			require.True(ok, "expected validator.ValidationErrors, got %T", err)
			translatedErr := validationErr.Translate(trans)

			// The translator does not seem to export a list of available translations or for a specific field.
			// So we need to hardcode expected strings. Needs to be in sync with implementation.
			switch tc.expectedErrorTag {
			case "no_provider":
				assert.Contains(translatedErr["ProviderConfig.Provider"], "No provider has been defined")
			case "more_than_one_provider":
				assert.Contains(translatedErr["ProviderConfig.Provider"], "Only one provider can be defined")
			default:
				// A new tag without a translation check must not pass silently.
				t.Fatalf("no translation check defined for error tag %q", tc.expectedErrorTag)
			}
		})
	}
}