mirror of
https://github.com/edgelesssys/constellation.git
synced 2025-01-12 16:09:39 -05:00
Add validation for zero or more than one provider
This commit is contained in:
parent
fb5faa681c
commit
7aded65ea8
@ -15,6 +15,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"io/fs"
|
"io/fs"
|
||||||
"regexp"
|
"regexp"
|
||||||
|
"strings"
|
||||||
|
|
||||||
"github.com/edgelesssys/constellation/internal/cloud/cloudprovider"
|
"github.com/edgelesssys/constellation/internal/cloud/cloudprovider"
|
||||||
"github.com/edgelesssys/constellation/internal/config/instancetypes"
|
"github.com/edgelesssys/constellation/internal/config/instancetypes"
|
||||||
@ -252,6 +253,28 @@ func validateGCPInstanceType(fl validator.FieldLevel) bool {
|
|||||||
return validInstanceTypeForProvider(fl.Field().String(), false, cloudprovider.GCP)
|
return validInstanceTypeForProvider(fl.Field().String(), false, cloudprovider.GCP)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// validateProvider checks if zero or more than one providers are defined in the config.
|
||||||
|
func validateProvider(sl validator.StructLevel) {
|
||||||
|
provider := sl.Current().Interface().(ProviderConfig)
|
||||||
|
providerCount := 0
|
||||||
|
|
||||||
|
if provider.Azure != nil {
|
||||||
|
providerCount++
|
||||||
|
}
|
||||||
|
if provider.GCP != nil {
|
||||||
|
providerCount++
|
||||||
|
}
|
||||||
|
if provider.QEMU != nil {
|
||||||
|
providerCount++
|
||||||
|
}
|
||||||
|
|
||||||
|
if providerCount < 1 {
|
||||||
|
sl.ReportError(provider, "Provider", "Provider", "no_provider", "")
|
||||||
|
} else if providerCount > 1 {
|
||||||
|
sl.ReportError(provider, "Provider", "Provider", "more_than_one_provider", "")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Validate checks the config values and returns validation error messages.
|
// Validate checks the config values and returns validation error messages.
|
||||||
// The function only returns an error if the validation itself fails.
|
// The function only returns an error if the validation itself fails.
|
||||||
func (c *Config) Validate() ([]string, error) {
|
func (c *Config) Validate() ([]string, error) {
|
||||||
@ -262,7 +285,7 @@ func (c *Config) Validate() ([]string, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Register Azure & GCP InstanceType validation error types
|
// Register Azure & GCP InstanceType validation error types
|
||||||
if err := validate.RegisterTranslation("azure_instance_type", trans, c.registerTranslateAzureInstanceTypeError, translateAzureInstanceTypeError); err != nil {
|
if err := validate.RegisterTranslation("azure_instance_type", trans, registerTranslateAzureInstanceTypeError, c.translateAzureInstanceTypeError); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -270,6 +293,15 @@ func (c *Config) Validate() ([]string, error) {
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Register Provider validation error types
|
||||||
|
if err := validate.RegisterTranslation("no_provider", trans, registerNoProviderError, translateNoProviderError); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := validate.RegisterTranslation("more_than_one_provider", trans, registerMoreThanOneProviderError, c.translateMoreThanOneProviderError); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
// register custom validator with label supported_k8s_version to validate version based on available versionConfigs.
|
// register custom validator with label supported_k8s_version to validate version based on available versionConfigs.
|
||||||
if err := validate.RegisterValidation("supported_k8s_version", validateK8sVersion); err != nil {
|
if err := validate.RegisterValidation("supported_k8s_version", validateK8sVersion); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@ -280,11 +312,14 @@ func (c *Config) Validate() ([]string, error) {
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// register custom validator with label azure_instance_type to validate version based on available versionConfigs.
|
// register custom validator with label gcp_instance_type to validate version based on available versionConfigs.
|
||||||
if err := validate.RegisterValidation("gcp_instance_type", validateGCPInstanceType); err != nil {
|
if err := validate.RegisterValidation("gcp_instance_type", validateGCPInstanceType); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Register provider validation
|
||||||
|
validate.RegisterStructValidation(validateProvider, ProviderConfig{})
|
||||||
|
|
||||||
err := validate.Struct(c)
|
err := validate.Struct(c)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
@ -302,18 +337,19 @@ func (c *Config) Validate() ([]string, error) {
|
|||||||
return msgs, nil
|
return msgs, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Validation translation functions for Azure & GCP instance type error functions.
|
// Validation translation functions for Azure & GCP instance type errors.
|
||||||
func (c *Config) registerTranslateAzureInstanceTypeError(ut ut.Translator) error {
|
func registerTranslateAzureInstanceTypeError(ut ut.Translator) error {
|
||||||
// Suggest trusted launch VMs if confidential VMs have been specifically disabled
|
return ut.Add("azure_instance_type", "{0} must be one of {1}", true)
|
||||||
if c.Provider.Azure != nil && c.Provider.Azure.ConfidentialVM != nil && !*c.Provider.Azure.ConfidentialVM {
|
|
||||||
return ut.Add("azure_instance_type", fmt.Sprintf("{0} must be one of %v", instancetypes.AzureTrustedLaunchInstanceTypes), true)
|
|
||||||
}
|
|
||||||
// Otherwise suggest CVMs
|
|
||||||
return ut.Add("azure_instance_type", fmt.Sprintf("{0} must be one of %v", instancetypes.AzureCVMInstanceTypes), true)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func translateAzureInstanceTypeError(ut ut.Translator, fe validator.FieldError) string {
|
func (c *Config) translateAzureInstanceTypeError(ut ut.Translator, fe validator.FieldError) string {
|
||||||
t, _ := ut.T("azure_instance_type", fe.Field())
|
// Suggest trusted launch VMs if confidential VMs have been specifically disabled
|
||||||
|
var t string
|
||||||
|
if c.Provider.Azure != nil && c.Provider.Azure.ConfidentialVM != nil && !*c.Provider.Azure.ConfidentialVM {
|
||||||
|
t, _ = ut.T("azure_instance_type", fe.Field(), fmt.Sprintf("%v", instancetypes.AzureTrustedLaunchInstanceTypes))
|
||||||
|
} else {
|
||||||
|
t, _ = ut.T("azure_instance_type", fe.Field(), fmt.Sprintf("%v", instancetypes.AzureCVMInstanceTypes))
|
||||||
|
}
|
||||||
|
|
||||||
return t
|
return t
|
||||||
}
|
}
|
||||||
@ -328,6 +364,41 @@ func translateGCPInstanceTypeError(ut ut.Translator, fe validator.FieldError) st
|
|||||||
return t
|
return t
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Validation translation functions for Provider errors.
|
||||||
|
func registerNoProviderError(ut ut.Translator) error {
|
||||||
|
return ut.Add("no_provider", "{0}: No provider has been defined (requires either Azure, GCP or QEMU)", true)
|
||||||
|
}
|
||||||
|
|
||||||
|
func translateNoProviderError(ut ut.Translator, fe validator.FieldError) string {
|
||||||
|
t, _ := ut.T("no_provider", fe.Field())
|
||||||
|
|
||||||
|
return t
|
||||||
|
}
|
||||||
|
|
||||||
|
func registerMoreThanOneProviderError(ut ut.Translator) error {
|
||||||
|
return ut.Add("more_than_one_provider", "{0}: Only one provider can be defined ({1} are defined)", true)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Config) translateMoreThanOneProviderError(ut ut.Translator, fe validator.FieldError) string {
|
||||||
|
definedProviders := make([]string, 0)
|
||||||
|
|
||||||
|
// c.Provider should not be nil as Provider would need to be defined for the validation to fail in this place.
|
||||||
|
if c.Provider.Azure != nil {
|
||||||
|
definedProviders = append(definedProviders, "Azure")
|
||||||
|
}
|
||||||
|
if c.Provider.GCP != nil {
|
||||||
|
definedProviders = append(definedProviders, "GCP")
|
||||||
|
}
|
||||||
|
if c.Provider.QEMU != nil {
|
||||||
|
definedProviders = append(definedProviders, "QEMU")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Show single string if only one other provider is defined, show list with brackets if multiple are defined.
|
||||||
|
t, _ := ut.T("more_than_one_provider", fe.Field(), strings.Join(definedProviders, ", "))
|
||||||
|
|
||||||
|
return t
|
||||||
|
}
|
||||||
|
|
||||||
// HasProvider checks whether the config contains the provider.
|
// HasProvider checks whether the config contains the provider.
|
||||||
func (c *Config) HasProvider(provider cloudprovider.Provider) bool {
|
func (c *Config) HasProvider(provider cloudprovider.Provider) bool {
|
||||||
switch provider {
|
switch provider {
|
||||||
|
@ -14,13 +14,16 @@ import (
|
|||||||
"github.com/edgelesssys/constellation/internal/config/instancetypes"
|
"github.com/edgelesssys/constellation/internal/config/instancetypes"
|
||||||
"github.com/edgelesssys/constellation/internal/constants"
|
"github.com/edgelesssys/constellation/internal/constants"
|
||||||
"github.com/edgelesssys/constellation/internal/file"
|
"github.com/edgelesssys/constellation/internal/file"
|
||||||
|
"github.com/go-playground/locales/en"
|
||||||
|
ut "github.com/go-playground/universal-translator"
|
||||||
|
"github.com/go-playground/validator/v10"
|
||||||
"github.com/spf13/afero"
|
"github.com/spf13/afero"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
"go.uber.org/goleak"
|
"go.uber.org/goleak"
|
||||||
)
|
)
|
||||||
|
|
||||||
const defaultMsgCount = 12 // expect this number of error messages by default because user-specific values are not set
|
const defaultMsgCount = 13 // expect this number of error messages by default because user-specific values are not set and multiple providers are defined by default
|
||||||
|
|
||||||
func TestMain(m *testing.M) {
|
func TestMain(m *testing.M) {
|
||||||
goleak.VerifyTestMain(m)
|
goleak.VerifyTestMain(m)
|
||||||
@ -506,3 +509,86 @@ func TestIsDebugCluster(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestValidateProvider(t *testing.T) {
|
||||||
|
testCases := map[string]struct {
|
||||||
|
provider ProviderConfig
|
||||||
|
wantErr bool
|
||||||
|
expectedErrorTag string
|
||||||
|
}{
|
||||||
|
"empty, should trigger no provider error": {
|
||||||
|
provider: ProviderConfig{},
|
||||||
|
wantErr: true,
|
||||||
|
expectedErrorTag: "no_provider",
|
||||||
|
},
|
||||||
|
"azure only, should be okay": {
|
||||||
|
provider: ProviderConfig{
|
||||||
|
Azure: &AzureConfig{},
|
||||||
|
},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
"gcp only, should be okay": {
|
||||||
|
provider: ProviderConfig{
|
||||||
|
GCP: &GCPConfig{},
|
||||||
|
},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
"qemu only, should be okay": {
|
||||||
|
provider: ProviderConfig{
|
||||||
|
QEMU: &QEMUConfig{},
|
||||||
|
},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
"azure and gcp, should trigger multiple provider error": {
|
||||||
|
provider: ProviderConfig{
|
||||||
|
Azure: &AzureConfig{},
|
||||||
|
GCP: &GCPConfig{},
|
||||||
|
},
|
||||||
|
wantErr: true,
|
||||||
|
expectedErrorTag: "more_than_one_provider",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for name, tc := range testCases {
|
||||||
|
t.Run(name, func(t *testing.T) {
|
||||||
|
assert := assert.New(t)
|
||||||
|
require := require.New(t)
|
||||||
|
v := validator.New()
|
||||||
|
trans := ut.New(en.New()).GetFallback()
|
||||||
|
|
||||||
|
conf := Default()
|
||||||
|
conf.Provider = tc.provider
|
||||||
|
|
||||||
|
v.RegisterStructValidation(validateProvider, ProviderConfig{})
|
||||||
|
err := v.StructPartial(tc.provider)
|
||||||
|
|
||||||
|
// Register provider validation error types.
|
||||||
|
// Make sure the tags and expected strings below are in sync with the actual implementation.
|
||||||
|
require.NoError(v.RegisterTranslation("no_provider", trans, registerNoProviderError, translateNoProviderError))
|
||||||
|
require.NoError(v.RegisterTranslation("more_than_one_provider", trans, registerMoreThanOneProviderError, conf.translateMoreThanOneProviderError))
|
||||||
|
|
||||||
|
// Continue if no error is expected.
|
||||||
|
if !tc.wantErr {
|
||||||
|
assert.NoError(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate if the error was identified correctly.
|
||||||
|
require.NotNil(err)
|
||||||
|
assert.Error(err)
|
||||||
|
assert.Contains(err.Error(), tc.expectedErrorTag)
|
||||||
|
|
||||||
|
// Check if error translation works correctly.
|
||||||
|
validationErr := err.(validator.ValidationErrors)
|
||||||
|
translatedErr := validationErr.Translate(trans)
|
||||||
|
|
||||||
|
// The translator does not seem to export a list of available translations or for a specific field.
|
||||||
|
// So we need to hardcode expected strings. Needs to be in sync with implementation.
|
||||||
|
switch tc.expectedErrorTag {
|
||||||
|
case "no_provider":
|
||||||
|
assert.Contains(translatedErr["ProviderConfig.Provider"], "No provider has been defined")
|
||||||
|
case "more_than_one_provider":
|
||||||
|
assert.Contains(translatedErr["ProviderConfig.Provider"], "Only one provider can be defined")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user