mirror of
https://github.com/edgelesssys/constellation.git
synced 2024-10-01 01:36:09 -04:00
Add validation for zero or more than one provider
This commit is contained in:
parent
fb5faa681c
commit
7aded65ea8
@ -15,6 +15,7 @@ import (
|
||||
"fmt"
|
||||
"io/fs"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"github.com/edgelesssys/constellation/internal/cloud/cloudprovider"
|
||||
"github.com/edgelesssys/constellation/internal/config/instancetypes"
|
||||
@ -252,6 +253,28 @@ func validateGCPInstanceType(fl validator.FieldLevel) bool {
|
||||
return validInstanceTypeForProvider(fl.Field().String(), false, cloudprovider.GCP)
|
||||
}
|
||||
|
||||
// validateProvider checks if zero or more than one providers are defined in the config.
|
||||
func validateProvider(sl validator.StructLevel) {
|
||||
provider := sl.Current().Interface().(ProviderConfig)
|
||||
providerCount := 0
|
||||
|
||||
if provider.Azure != nil {
|
||||
providerCount++
|
||||
}
|
||||
if provider.GCP != nil {
|
||||
providerCount++
|
||||
}
|
||||
if provider.QEMU != nil {
|
||||
providerCount++
|
||||
}
|
||||
|
||||
if providerCount < 1 {
|
||||
sl.ReportError(provider, "Provider", "Provider", "no_provider", "")
|
||||
} else if providerCount > 1 {
|
||||
sl.ReportError(provider, "Provider", "Provider", "more_than_one_provider", "")
|
||||
}
|
||||
}
|
||||
|
||||
// Validate checks the config values and returns validation error messages.
|
||||
// The function only returns an error if the validation itself fails.
|
||||
func (c *Config) Validate() ([]string, error) {
|
||||
@ -262,7 +285,7 @@ func (c *Config) Validate() ([]string, error) {
|
||||
}
|
||||
|
||||
// Register Azure & GCP InstanceType validation error types
|
||||
if err := validate.RegisterTranslation("azure_instance_type", trans, c.registerTranslateAzureInstanceTypeError, translateAzureInstanceTypeError); err != nil {
|
||||
if err := validate.RegisterTranslation("azure_instance_type", trans, registerTranslateAzureInstanceTypeError, c.translateAzureInstanceTypeError); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@ -270,6 +293,15 @@ func (c *Config) Validate() ([]string, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Register Provider validation error types
|
||||
if err := validate.RegisterTranslation("no_provider", trans, registerNoProviderError, translateNoProviderError); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := validate.RegisterTranslation("more_than_one_provider", trans, registerMoreThanOneProviderError, c.translateMoreThanOneProviderError); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// register custom validator with label supported_k8s_version to validate version based on available versionConfigs.
|
||||
if err := validate.RegisterValidation("supported_k8s_version", validateK8sVersion); err != nil {
|
||||
return nil, err
|
||||
@ -280,11 +312,14 @@ func (c *Config) Validate() ([]string, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// register custom validator with label azure_instance_type to validate version based on available versionConfigs.
|
||||
// register custom validator with label gcp_instance_type to validate version based on available versionConfigs.
|
||||
if err := validate.RegisterValidation("gcp_instance_type", validateGCPInstanceType); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Register provider validation
|
||||
validate.RegisterStructValidation(validateProvider, ProviderConfig{})
|
||||
|
||||
err := validate.Struct(c)
|
||||
if err == nil {
|
||||
return nil, nil
|
||||
@ -302,18 +337,19 @@ func (c *Config) Validate() ([]string, error) {
|
||||
return msgs, nil
|
||||
}
|
||||
|
||||
// Validation translation functions for Azure & GCP instance type error functions.
|
||||
func (c *Config) registerTranslateAzureInstanceTypeError(ut ut.Translator) error {
|
||||
// Suggest trusted launch VMs if confidential VMs have been specifically disabled
|
||||
if c.Provider.Azure != nil && c.Provider.Azure.ConfidentialVM != nil && !*c.Provider.Azure.ConfidentialVM {
|
||||
return ut.Add("azure_instance_type", fmt.Sprintf("{0} must be one of %v", instancetypes.AzureTrustedLaunchInstanceTypes), true)
|
||||
}
|
||||
// Otherwise suggest CVMs
|
||||
return ut.Add("azure_instance_type", fmt.Sprintf("{0} must be one of %v", instancetypes.AzureCVMInstanceTypes), true)
|
||||
// Validation translation functions for Azure & GCP instance type errors.
|
||||
func registerTranslateAzureInstanceTypeError(ut ut.Translator) error {
|
||||
return ut.Add("azure_instance_type", "{0} must be one of {1}", true)
|
||||
}
|
||||
|
||||
func translateAzureInstanceTypeError(ut ut.Translator, fe validator.FieldError) string {
|
||||
t, _ := ut.T("azure_instance_type", fe.Field())
|
||||
func (c *Config) translateAzureInstanceTypeError(ut ut.Translator, fe validator.FieldError) string {
|
||||
// Suggest trusted launch VMs if confidential VMs have been specifically disabled
|
||||
var t string
|
||||
if c.Provider.Azure != nil && c.Provider.Azure.ConfidentialVM != nil && !*c.Provider.Azure.ConfidentialVM {
|
||||
t, _ = ut.T("azure_instance_type", fe.Field(), fmt.Sprintf("%v", instancetypes.AzureTrustedLaunchInstanceTypes))
|
||||
} else {
|
||||
t, _ = ut.T("azure_instance_type", fe.Field(), fmt.Sprintf("%v", instancetypes.AzureCVMInstanceTypes))
|
||||
}
|
||||
|
||||
return t
|
||||
}
|
||||
@ -328,6 +364,41 @@ func translateGCPInstanceTypeError(ut ut.Translator, fe validator.FieldError) st
|
||||
return t
|
||||
}
|
||||
|
||||
// Validation translation functions for Provider errors.
|
||||
func registerNoProviderError(ut ut.Translator) error {
|
||||
return ut.Add("no_provider", "{0}: No provider has been defined (requires either Azure, GCP or QEMU)", true)
|
||||
}
|
||||
|
||||
func translateNoProviderError(ut ut.Translator, fe validator.FieldError) string {
|
||||
t, _ := ut.T("no_provider", fe.Field())
|
||||
|
||||
return t
|
||||
}
|
||||
|
||||
func registerMoreThanOneProviderError(ut ut.Translator) error {
|
||||
return ut.Add("more_than_one_provider", "{0}: Only one provider can be defined ({1} are defined)", true)
|
||||
}
|
||||
|
||||
func (c *Config) translateMoreThanOneProviderError(ut ut.Translator, fe validator.FieldError) string {
|
||||
definedProviders := make([]string, 0)
|
||||
|
||||
// c.Provider should not be nil as Provider would need to be defined for the validation to fail in this place.
|
||||
if c.Provider.Azure != nil {
|
||||
definedProviders = append(definedProviders, "Azure")
|
||||
}
|
||||
if c.Provider.GCP != nil {
|
||||
definedProviders = append(definedProviders, "GCP")
|
||||
}
|
||||
if c.Provider.QEMU != nil {
|
||||
definedProviders = append(definedProviders, "QEMU")
|
||||
}
|
||||
|
||||
// Show single string if only one other provider is defined, show list with brackets if multiple are defined.
|
||||
t, _ := ut.T("more_than_one_provider", fe.Field(), strings.Join(definedProviders, ", "))
|
||||
|
||||
return t
|
||||
}
|
||||
|
||||
// HasProvider checks whether the config contains the provider.
|
||||
func (c *Config) HasProvider(provider cloudprovider.Provider) bool {
|
||||
switch provider {
|
||||
|
@ -14,13 +14,16 @@ import (
|
||||
"github.com/edgelesssys/constellation/internal/config/instancetypes"
|
||||
"github.com/edgelesssys/constellation/internal/constants"
|
||||
"github.com/edgelesssys/constellation/internal/file"
|
||||
"github.com/go-playground/locales/en"
|
||||
ut "github.com/go-playground/universal-translator"
|
||||
"github.com/go-playground/validator/v10"
|
||||
"github.com/spf13/afero"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/goleak"
|
||||
)
|
||||
|
||||
const defaultMsgCount = 12 // expect this number of error messages by default because user-specific values are not set
|
||||
const defaultMsgCount = 13 // expect this number of error messages by default because user-specific values are not set and multiple providers are defined by default
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
goleak.VerifyTestMain(m)
|
||||
@ -506,3 +509,86 @@ func TestIsDebugCluster(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateProvider(t *testing.T) {
|
||||
testCases := map[string]struct {
|
||||
provider ProviderConfig
|
||||
wantErr bool
|
||||
expectedErrorTag string
|
||||
}{
|
||||
"empty, should trigger no provider error": {
|
||||
provider: ProviderConfig{},
|
||||
wantErr: true,
|
||||
expectedErrorTag: "no_provider",
|
||||
},
|
||||
"azure only, should be okay": {
|
||||
provider: ProviderConfig{
|
||||
Azure: &AzureConfig{},
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
"gcp only, should be okay": {
|
||||
provider: ProviderConfig{
|
||||
GCP: &GCPConfig{},
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
"qemu only, should be okay": {
|
||||
provider: ProviderConfig{
|
||||
QEMU: &QEMUConfig{},
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
"azure and gcp, should trigger multiple provider error": {
|
||||
provider: ProviderConfig{
|
||||
Azure: &AzureConfig{},
|
||||
GCP: &GCPConfig{},
|
||||
},
|
||||
wantErr: true,
|
||||
expectedErrorTag: "more_than_one_provider",
|
||||
},
|
||||
}
|
||||
for name, tc := range testCases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
require := require.New(t)
|
||||
v := validator.New()
|
||||
trans := ut.New(en.New()).GetFallback()
|
||||
|
||||
conf := Default()
|
||||
conf.Provider = tc.provider
|
||||
|
||||
v.RegisterStructValidation(validateProvider, ProviderConfig{})
|
||||
err := v.StructPartial(tc.provider)
|
||||
|
||||
// Register provider validation error types.
|
||||
// Make sure the tags and expected strings below are in sync with the actual implementation.
|
||||
require.NoError(v.RegisterTranslation("no_provider", trans, registerNoProviderError, translateNoProviderError))
|
||||
require.NoError(v.RegisterTranslation("more_than_one_provider", trans, registerMoreThanOneProviderError, conf.translateMoreThanOneProviderError))
|
||||
|
||||
// Continue if no error is expected.
|
||||
if !tc.wantErr {
|
||||
assert.NoError(err)
|
||||
return
|
||||
}
|
||||
|
||||
// Validate if the error was identified correctly.
|
||||
require.NotNil(err)
|
||||
assert.Error(err)
|
||||
assert.Contains(err.Error(), tc.expectedErrorTag)
|
||||
|
||||
// Check if error translation works correctly.
|
||||
validationErr := err.(validator.ValidationErrors)
|
||||
translatedErr := validationErr.Translate(trans)
|
||||
|
||||
// The translator does not seem to export a list of available translations or for a specific field.
|
||||
// So we need to hardcode expected strings. Needs to be in sync with implementation.
|
||||
switch tc.expectedErrorTag {
|
||||
case "no_provider":
|
||||
assert.Contains(translatedErr["ProviderConfig.Provider"], "No provider has been defined")
|
||||
case "more_than_one_provider":
|
||||
assert.Contains(translatedErr["ProviderConfig.Provider"], "Only one provider can be defined")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user