Add validation for zero or more than one provider

This commit is contained in:
Nils Hanke 2022-09-07 11:53:44 +02:00 committed by Nils Hanke
parent fb5faa681c
commit 7aded65ea8
2 changed files with 170 additions and 13 deletions

View File

@ -15,6 +15,7 @@ import (
"fmt" "fmt"
"io/fs" "io/fs"
"regexp" "regexp"
"strings"
"github.com/edgelesssys/constellation/internal/cloud/cloudprovider" "github.com/edgelesssys/constellation/internal/cloud/cloudprovider"
"github.com/edgelesssys/constellation/internal/config/instancetypes" "github.com/edgelesssys/constellation/internal/config/instancetypes"
@ -252,6 +253,28 @@ func validateGCPInstanceType(fl validator.FieldLevel) bool {
return validInstanceTypeForProvider(fl.Field().String(), false, cloudprovider.GCP) return validInstanceTypeForProvider(fl.Field().String(), false, cloudprovider.GCP)
} }
// validateProvider checks if zero or more than one providers are defined in the config.
func validateProvider(sl validator.StructLevel) {
provider := sl.Current().Interface().(ProviderConfig)
providerCount := 0
if provider.Azure != nil {
providerCount++
}
if provider.GCP != nil {
providerCount++
}
if provider.QEMU != nil {
providerCount++
}
if providerCount < 1 {
sl.ReportError(provider, "Provider", "Provider", "no_provider", "")
} else if providerCount > 1 {
sl.ReportError(provider, "Provider", "Provider", "more_than_one_provider", "")
}
}
// Validate checks the config values and returns validation error messages. // Validate checks the config values and returns validation error messages.
// The function only returns an error if the validation itself fails. // The function only returns an error if the validation itself fails.
func (c *Config) Validate() ([]string, error) { func (c *Config) Validate() ([]string, error) {
@ -262,7 +285,7 @@ func (c *Config) Validate() ([]string, error) {
} }
// Register Azure & GCP InstanceType validation error types // Register Azure & GCP InstanceType validation error types
if err := validate.RegisterTranslation("azure_instance_type", trans, c.registerTranslateAzureInstanceTypeError, translateAzureInstanceTypeError); err != nil { if err := validate.RegisterTranslation("azure_instance_type", trans, registerTranslateAzureInstanceTypeError, c.translateAzureInstanceTypeError); err != nil {
return nil, err return nil, err
} }
@ -270,6 +293,15 @@ func (c *Config) Validate() ([]string, error) {
return nil, err return nil, err
} }
// Register Provider validation error types
if err := validate.RegisterTranslation("no_provider", trans, registerNoProviderError, translateNoProviderError); err != nil {
return nil, err
}
if err := validate.RegisterTranslation("more_than_one_provider", trans, registerMoreThanOneProviderError, c.translateMoreThanOneProviderError); err != nil {
return nil, err
}
// register custom validator with label supported_k8s_version to validate version based on available versionConfigs. // register custom validator with label supported_k8s_version to validate version based on available versionConfigs.
if err := validate.RegisterValidation("supported_k8s_version", validateK8sVersion); err != nil { if err := validate.RegisterValidation("supported_k8s_version", validateK8sVersion); err != nil {
return nil, err return nil, err
@ -280,11 +312,14 @@ func (c *Config) Validate() ([]string, error) {
return nil, err return nil, err
} }
// register custom validator with label azure_instance_type to validate version based on available versionConfigs. // register custom validator with label gcp_instance_type to validate version based on available versionConfigs.
if err := validate.RegisterValidation("gcp_instance_type", validateGCPInstanceType); err != nil { if err := validate.RegisterValidation("gcp_instance_type", validateGCPInstanceType); err != nil {
return nil, err return nil, err
} }
// Register provider validation
validate.RegisterStructValidation(validateProvider, ProviderConfig{})
err := validate.Struct(c) err := validate.Struct(c)
if err == nil { if err == nil {
return nil, nil return nil, nil
@ -302,18 +337,19 @@ func (c *Config) Validate() ([]string, error) {
return msgs, nil return msgs, nil
} }
// Validation translation functions for Azure & GCP instance type error functions. // Validation translation functions for Azure & GCP instance type errors.
func (c *Config) registerTranslateAzureInstanceTypeError(ut ut.Translator) error { func registerTranslateAzureInstanceTypeError(ut ut.Translator) error {
// Suggest trusted launch VMs if confidential VMs have been specifically disabled return ut.Add("azure_instance_type", "{0} must be one of {1}", true)
if c.Provider.Azure != nil && c.Provider.Azure.ConfidentialVM != nil && !*c.Provider.Azure.ConfidentialVM {
return ut.Add("azure_instance_type", fmt.Sprintf("{0} must be one of %v", instancetypes.AzureTrustedLaunchInstanceTypes), true)
}
// Otherwise suggest CVMs
return ut.Add("azure_instance_type", fmt.Sprintf("{0} must be one of %v", instancetypes.AzureCVMInstanceTypes), true)
} }
func translateAzureInstanceTypeError(ut ut.Translator, fe validator.FieldError) string { func (c *Config) translateAzureInstanceTypeError(ut ut.Translator, fe validator.FieldError) string {
t, _ := ut.T("azure_instance_type", fe.Field()) // Suggest trusted launch VMs if confidential VMs have been specifically disabled
var t string
if c.Provider.Azure != nil && c.Provider.Azure.ConfidentialVM != nil && !*c.Provider.Azure.ConfidentialVM {
t, _ = ut.T("azure_instance_type", fe.Field(), fmt.Sprintf("%v", instancetypes.AzureTrustedLaunchInstanceTypes))
} else {
t, _ = ut.T("azure_instance_type", fe.Field(), fmt.Sprintf("%v", instancetypes.AzureCVMInstanceTypes))
}
return t return t
} }
@ -328,6 +364,41 @@ func translateGCPInstanceTypeError(ut ut.Translator, fe validator.FieldError) st
return t return t
} }
// Validation translation functions for Provider errors.
func registerNoProviderError(ut ut.Translator) error {
return ut.Add("no_provider", "{0}: No provider has been defined (requires either Azure, GCP or QEMU)", true)
}
func translateNoProviderError(ut ut.Translator, fe validator.FieldError) string {
t, _ := ut.T("no_provider", fe.Field())
return t
}
func registerMoreThanOneProviderError(ut ut.Translator) error {
return ut.Add("more_than_one_provider", "{0}: Only one provider can be defined ({1} are defined)", true)
}
func (c *Config) translateMoreThanOneProviderError(ut ut.Translator, fe validator.FieldError) string {
definedProviders := make([]string, 0)
// c.Provider should not be nil as Provider would need to be defined for the validation to fail in this place.
if c.Provider.Azure != nil {
definedProviders = append(definedProviders, "Azure")
}
if c.Provider.GCP != nil {
definedProviders = append(definedProviders, "GCP")
}
if c.Provider.QEMU != nil {
definedProviders = append(definedProviders, "QEMU")
}
// Show single string if only one other provider is defined, show list with brackets if multiple are defined.
t, _ := ut.T("more_than_one_provider", fe.Field(), strings.Join(definedProviders, ", "))
return t
}
// HasProvider checks whether the config contains the provider. // HasProvider checks whether the config contains the provider.
func (c *Config) HasProvider(provider cloudprovider.Provider) bool { func (c *Config) HasProvider(provider cloudprovider.Provider) bool {
switch provider { switch provider {

View File

@ -14,13 +14,16 @@ import (
"github.com/edgelesssys/constellation/internal/config/instancetypes" "github.com/edgelesssys/constellation/internal/config/instancetypes"
"github.com/edgelesssys/constellation/internal/constants" "github.com/edgelesssys/constellation/internal/constants"
"github.com/edgelesssys/constellation/internal/file" "github.com/edgelesssys/constellation/internal/file"
"github.com/go-playground/locales/en"
ut "github.com/go-playground/universal-translator"
"github.com/go-playground/validator/v10"
"github.com/spf13/afero" "github.com/spf13/afero"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"go.uber.org/goleak" "go.uber.org/goleak"
) )
const defaultMsgCount = 12 // expect this number of error messages by default because user-specific values are not set const defaultMsgCount = 13 // expect this number of error messages by default because user-specific values are not set and multiple providers are defined by default
func TestMain(m *testing.M) { func TestMain(m *testing.M) {
goleak.VerifyTestMain(m) goleak.VerifyTestMain(m)
@ -506,3 +509,86 @@ func TestIsDebugCluster(t *testing.T) {
}) })
} }
} }
func TestValidateProvider(t *testing.T) {
testCases := map[string]struct {
provider ProviderConfig
wantErr bool
expectedErrorTag string
}{
"empty, should trigger no provider error": {
provider: ProviderConfig{},
wantErr: true,
expectedErrorTag: "no_provider",
},
"azure only, should be okay": {
provider: ProviderConfig{
Azure: &AzureConfig{},
},
wantErr: false,
},
"gcp only, should be okay": {
provider: ProviderConfig{
GCP: &GCPConfig{},
},
wantErr: false,
},
"qemu only, should be okay": {
provider: ProviderConfig{
QEMU: &QEMUConfig{},
},
wantErr: false,
},
"azure and gcp, should trigger multiple provider error": {
provider: ProviderConfig{
Azure: &AzureConfig{},
GCP: &GCPConfig{},
},
wantErr: true,
expectedErrorTag: "more_than_one_provider",
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
v := validator.New()
trans := ut.New(en.New()).GetFallback()
conf := Default()
conf.Provider = tc.provider
v.RegisterStructValidation(validateProvider, ProviderConfig{})
err := v.StructPartial(tc.provider)
// Register provider validation error types.
// Make sure the tags and expected strings below are in sync with the actual implementation.
require.NoError(v.RegisterTranslation("no_provider", trans, registerNoProviderError, translateNoProviderError))
require.NoError(v.RegisterTranslation("more_than_one_provider", trans, registerMoreThanOneProviderError, conf.translateMoreThanOneProviderError))
// Continue if no error is expected.
if !tc.wantErr {
assert.NoError(err)
return
}
// Validate if the error was identified correctly.
require.NotNil(err)
assert.Error(err)
assert.Contains(err.Error(), tc.expectedErrorTag)
// Check if error translation works correctly.
validationErr := err.(validator.ValidationErrors)
translatedErr := validationErr.Translate(trans)
// The translator does not seem to export a list of available translations or for a specific field.
// So we need to hardcode expected strings. Needs to be in sync with implementation.
switch tc.expectedErrorTag {
case "no_provider":
assert.Contains(translatedErr["ProviderConfig.Provider"], "No provider has been defined")
case "more_than_one_provider":
assert.Contains(translatedErr["ProviderConfig.Provider"], "Only one provider can be defined")
}
})
}
}