mirror of
https://github.com/edgelesssys/constellation.git
synced 2024-10-01 01:36:09 -04:00
cli: state file validation (#2523)
* re-use `ReadFromFile` in `CreateOrRead` Signed-off-by: Moritz Sanft <58110325+msanft@users.noreply.github.com> * [wip]: add constraints Signed-off-by: Moritz Sanft <58110325+msanft@users.noreply.github.com> * [wip] error formatting Signed-off-by: Moritz Sanft <58110325+msanft@users.noreply.github.com> * wip Signed-off-by: Moritz Sanft <58110325+msanft@users.noreply.github.com> * formatted error messages Signed-off-by: Moritz Sanft <58110325+msanft@users.noreply.github.com> * state file validation Signed-off-by: Moritz Sanft <58110325+msanft@users.noreply.github.com> * linter fixes Signed-off-by: Moritz Sanft <58110325+msanft@users.noreply.github.com> * allow overriding the constraints Signed-off-by: Moritz Sanft <58110325+msanft@users.noreply.github.com> * dont validate on read Signed-off-by: Moritz Sanft <58110325+msanft@users.noreply.github.com> * add pre-create constraints Signed-off-by: Moritz Sanft <58110325+msanft@users.noreply.github.com> * [wip] Signed-off-by: Moritz Sanft <58110325+msanft@users.noreply.github.com> * finish pre-init validation test Signed-off-by: Moritz Sanft <58110325+msanft@users.noreply.github.com> * finish post-init validation Signed-off-by: Moritz Sanft <58110325+msanft@users.noreply.github.com> * use state file validation in CLI Signed-off-by: Moritz Sanft <58110325+msanft@users.noreply.github.com> * fix apply tests Signed-off-by: Moritz Sanft <58110325+msanft@users.noreply.github.com> * Update internal/validation/errors.go Co-authored-by: Daniel Weiße <66256922+daniel-weisse@users.noreply.github.com> * use transformator for tests * tidy * use empty check directly Signed-off-by: Moritz Sanft <58110325+msanft@users.noreply.github.com> * Update cli/internal/state/state.go Co-authored-by: Daniel Weiße <66256922+daniel-weisse@users.noreply.github.com> * Update cli/internal/state/state.go Co-authored-by: Daniel Weiße <66256922+daniel-weisse@users.noreply.github.com> * Update cli/internal/state/state.go Co-authored-by: Daniel Weiße <66256922+daniel-weisse@users.noreply.github.com> * Update cli/internal/state/state.go Co-authored-by: Daniel Weiße <66256922+daniel-weisse@users.noreply.github.com> * conditional validation per CSP Signed-off-by: Moritz Sanft <58110325+msanft@users.noreply.github.com> * tidy Signed-off-by: Moritz Sanft <58110325+msanft@users.noreply.github.com> * fix rebase Signed-off-by: Moritz Sanft <58110325+msanft@users.noreply.github.com> * add default case Signed-off-by: Moritz Sanft <58110325+msanft@users.noreply.github.com> * validate state-file as last input Signed-off-by: Moritz Sanft <58110325+msanft@users.noreply.github.com> --------- Signed-off-by: Moritz Sanft <58110325+msanft@users.noreply.github.com> Co-authored-by: Daniel Weiße <66256922+daniel-weisse@users.noreply.github.com>
This commit is contained in:
parent
eaec73cca4
commit
744a605602
@ -419,12 +419,6 @@ func (a *applyCmd) validateInputs(cmd *cobra.Command, configFetcher attestationc
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
a.log.Debugf("Reading state file from %s", a.flags.pathPrefixer.PrefixPrintablePath(constants.StateFilename))
|
||||
stateFile, err := state.ReadFromFile(a.fileHandler, constants.StateFilename)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
// Check license
|
||||
a.log.Debugf("Running license check")
|
||||
checker := license.NewChecker(a.quotaChecker, a.fileHandler)
|
||||
@ -517,6 +511,27 @@ func (a *applyCmd) validateInputs(cmd *cobra.Command, configFetcher attestationc
|
||||
cmd.PrintErrln("WARNING: Attestation temporarily relies on AWS nitroTPM. See https://docs.edgeless.systems/constellation/workflows/config#choosing-a-vm-type for more information.")
|
||||
}
|
||||
|
||||
// Read and validate state file
|
||||
// This needs to be done as a last step, as we need to parse all other inputs to
|
||||
// know which phases are skipped.
|
||||
a.log.Debugf("Reading state file from %s", a.flags.pathPrefixer.PrefixPrintablePath(constants.StateFilename))
|
||||
stateFile, err := state.ReadFromFile(a.fileHandler, constants.StateFilename)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
if a.flags.skipPhases.contains(skipInitPhase) {
|
||||
// If the skipInit flag is set, we are in a state where the cluster
|
||||
// has already been initialized and check against the respective constraints.
|
||||
if err := stateFile.Validate(state.PostInit, conf.GetProvider()); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
} else {
|
||||
// The cluster has not been initialized yet, so we check against the pre-init constraints.
|
||||
if err := stateFile.Validate(state.PreInit, conf.GetProvider()); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return conf, stateFile, nil
|
||||
}
|
||||
|
||||
|
@ -14,6 +14,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/edgelesssys/constellation/v2/cli/internal/helm"
|
||||
"github.com/edgelesssys/constellation/v2/cli/internal/state"
|
||||
"github.com/edgelesssys/constellation/v2/internal/file"
|
||||
"github.com/edgelesssys/constellation/v2/internal/logger"
|
||||
"github.com/spf13/afero"
|
||||
@ -22,6 +23,54 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// defaultStateFile returns a valid default state for testing.
|
||||
func defaultStateFile() *state.State {
|
||||
return &state.State{
|
||||
Version: "v1",
|
||||
Infrastructure: state.Infrastructure{
|
||||
UID: "123",
|
||||
Name: "test-cluster",
|
||||
ClusterEndpoint: "192.0.2.1",
|
||||
InClusterEndpoint: "192.0.2.1",
|
||||
InitSecret: []byte{0x41},
|
||||
APIServerCertSANs: []string{
|
||||
"127.0.0.1",
|
||||
"www.example.com",
|
||||
},
|
||||
IPCidrNode: "0.0.0.0/24",
|
||||
Azure: &state.Azure{
|
||||
ResourceGroup: "test-rg",
|
||||
SubscriptionID: "test-sub",
|
||||
NetworkSecurityGroupName: "test-nsg",
|
||||
LoadBalancerName: "test-lb",
|
||||
UserAssignedIdentity: "test-uami",
|
||||
AttestationURL: "test-maaUrl",
|
||||
},
|
||||
GCP: &state.GCP{
|
||||
ProjectID: "test-project",
|
||||
IPCidrPod: "0.0.0.0/24",
|
||||
},
|
||||
},
|
||||
ClusterValues: state.ClusterValues{
|
||||
ClusterID: "deadbeef",
|
||||
OwnerID: "deadbeef",
|
||||
MeasurementSalt: []byte{0x41},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func defaultAzureStateFile() *state.State {
|
||||
s := defaultStateFile()
|
||||
s.Infrastructure.GCP = nil
|
||||
return s
|
||||
}
|
||||
|
||||
func defaultGCPStateFile() *state.State {
|
||||
s := defaultStateFile()
|
||||
s.Infrastructure.Azure = nil
|
||||
return s
|
||||
}
|
||||
|
||||
func TestParseApplyFlags(t *testing.T) {
|
||||
require := require.New(t)
|
||||
defaultFlags := func() *pflag.FlagSet {
|
||||
|
@ -202,6 +202,9 @@ func (c *createCmd) create(cmd *cobra.Command, applier cloudApplier, fileHandler
|
||||
if err != nil {
|
||||
return fmt.Errorf("reading state file: %w", err)
|
||||
}
|
||||
if err := stateFile.Validate(state.PreCreate, conf.GetProvider()); err != nil {
|
||||
return fmt.Errorf("validating state file: %w", err)
|
||||
}
|
||||
stateFile = stateFile.SetInfrastructure(infraState)
|
||||
if err := stateFile.WriteToFile(fileHandler, constants.StateFilename); err != nil {
|
||||
return fmt.Errorf("writing state file: %w", err)
|
||||
|
@ -22,12 +22,21 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// preCreateStateFile returns a state file satisfying the pre-create state file
|
||||
// constraints.
|
||||
func preCreateStateFile() *state.State {
|
||||
s := defaultAzureStateFile()
|
||||
s.ClusterValues = state.ClusterValues{}
|
||||
s.Infrastructure = state.Infrastructure{}
|
||||
return s
|
||||
}
|
||||
|
||||
func TestCreate(t *testing.T) {
|
||||
fsWithDefaultConfigAndState := func(require *require.Assertions, provider cloudprovider.Provider) afero.Fs {
|
||||
fs := afero.NewMemMapFs()
|
||||
file := file.NewHandler(fs)
|
||||
require.NoError(file.WriteYAML(constants.ConfigFilename, defaultConfigWithExpectedMeasurements(t, config.Default(), provider)))
|
||||
stateFile := state.New()
|
||||
stateFile := preCreateStateFile()
|
||||
switch provider {
|
||||
case cloudprovider.GCP:
|
||||
stateFile.SetInfrastructure(state.Infrastructure{GCP: &state.GCP{}})
|
||||
|
@ -59,6 +59,14 @@ func TestInitArgumentValidation(t *testing.T) {
|
||||
assert.Error(cmd.ValidateArgs([]string{"sth", "sth"}))
|
||||
}
|
||||
|
||||
// preInitStateFile returns a state file satisfying the pre-init state file
|
||||
// constraints.
|
||||
func preInitStateFile() *state.State {
|
||||
s := defaultAzureStateFile()
|
||||
s.ClusterValues = state.ClusterValues{}
|
||||
return s
|
||||
}
|
||||
|
||||
func TestInitialize(t *testing.T) {
|
||||
respKubeconfig := k8sclientapi.Config{
|
||||
Clusters: map[string]*k8sclientapi.Cluster{
|
||||
@ -101,24 +109,24 @@ func TestInitialize(t *testing.T) {
|
||||
}{
|
||||
"initialize some gcp instances": {
|
||||
provider: cloudprovider.GCP,
|
||||
stateFile: &state.State{Version: state.Version1, Infrastructure: state.Infrastructure{ClusterEndpoint: "192.0.2.1"}},
|
||||
stateFile: preInitStateFile(),
|
||||
configMutator: func(c *config.Config) { c.Provider.GCP.ServiceAccountKeyPath = serviceAccPath },
|
||||
serviceAccKey: gcpServiceAccKey,
|
||||
initServerAPI: &stubInitServer{res: []*initproto.InitResponse{{Kind: &initproto.InitResponse_InitSuccess{InitSuccess: testInitResp}}}},
|
||||
},
|
||||
"initialize some azure instances": {
|
||||
provider: cloudprovider.Azure,
|
||||
stateFile: &state.State{Version: state.Version1, Infrastructure: state.Infrastructure{ClusterEndpoint: "192.0.2.1"}},
|
||||
stateFile: preInitStateFile(),
|
||||
initServerAPI: &stubInitServer{res: []*initproto.InitResponse{{Kind: &initproto.InitResponse_InitSuccess{InitSuccess: testInitResp}}}},
|
||||
},
|
||||
"initialize some qemu instances": {
|
||||
provider: cloudprovider.QEMU,
|
||||
stateFile: &state.State{Version: state.Version1, Infrastructure: state.Infrastructure{ClusterEndpoint: "192.0.2.1"}},
|
||||
stateFile: preInitStateFile(),
|
||||
initServerAPI: &stubInitServer{res: []*initproto.InitResponse{{Kind: &initproto.InitResponse_InitSuccess{InitSuccess: testInitResp}}}},
|
||||
},
|
||||
"non retriable error": {
|
||||
provider: cloudprovider.QEMU,
|
||||
stateFile: &state.State{Version: state.Version1, Infrastructure: state.Infrastructure{ClusterEndpoint: "192.0.2.1"}},
|
||||
stateFile: preInitStateFile(),
|
||||
initServerAPI: &stubInitServer{initErr: &nonRetriableError{err: assert.AnError}},
|
||||
retriable: false,
|
||||
masterSecretShouldExist: true,
|
||||
@ -126,7 +134,7 @@ func TestInitialize(t *testing.T) {
|
||||
},
|
||||
"non retriable error with failed log collection": {
|
||||
provider: cloudprovider.QEMU,
|
||||
stateFile: &state.State{Version: state.Version1, Infrastructure: state.Infrastructure{ClusterEndpoint: "192.0.2.1"}},
|
||||
stateFile: preInitStateFile(),
|
||||
initServerAPI: &stubInitServer{
|
||||
res: []*initproto.InitResponse{
|
||||
{
|
||||
@ -149,31 +157,24 @@ func TestInitialize(t *testing.T) {
|
||||
masterSecretShouldExist: true,
|
||||
wantErr: true,
|
||||
},
|
||||
/*
|
||||
Tests currently disabled since we don't actually have validation for the state file yet
|
||||
These tests cases only passed in the past because of unrelated errors in the test setup
|
||||
TODO(AB#3492): Re-enable tests once state file validation is implemented
|
||||
|
||||
"state file with only version": {
|
||||
provider: cloudprovider.GCP,
|
||||
stateFile: &state.State{Version: state.Version1},
|
||||
configMutator: func(c *config.Config) { c.Provider.GCP.ServiceAccountKeyPath = serviceAccPath },
|
||||
serviceAccKey: gcpServiceAccKey,
|
||||
initServerAPI: &stubInitServer{},
|
||||
retriable: true,
|
||||
wantErr: true,
|
||||
},
|
||||
|
||||
"empty state file": {
|
||||
provider: cloudprovider.GCP,
|
||||
stateFile: &state.State{},
|
||||
configMutator: func(c *config.Config) { c.Provider.GCP.ServiceAccountKeyPath = serviceAccPath },
|
||||
serviceAccKey: gcpServiceAccKey,
|
||||
initServerAPI: &stubInitServer{},
|
||||
retriable: true,
|
||||
wantErr: true,
|
||||
},
|
||||
*/
|
||||
"invalid state file": {
|
||||
provider: cloudprovider.GCP,
|
||||
stateFile: &state.State{Version: "invalid"},
|
||||
configMutator: func(c *config.Config) { c.Provider.GCP.ServiceAccountKeyPath = serviceAccPath },
|
||||
serviceAccKey: gcpServiceAccKey,
|
||||
initServerAPI: &stubInitServer{},
|
||||
retriable: true,
|
||||
wantErr: true,
|
||||
},
|
||||
"empty state file": {
|
||||
provider: cloudprovider.GCP,
|
||||
stateFile: &state.State{},
|
||||
configMutator: func(c *config.Config) { c.Provider.GCP.ServiceAccountKeyPath = serviceAccPath },
|
||||
serviceAccKey: gcpServiceAccKey,
|
||||
initServerAPI: &stubInitServer{},
|
||||
retriable: true,
|
||||
wantErr: true,
|
||||
},
|
||||
"no state file": {
|
||||
provider: cloudprovider.GCP,
|
||||
configMutator: func(c *config.Config) { c.Provider.GCP.ServiceAccountKeyPath = serviceAccPath },
|
||||
@ -184,7 +185,7 @@ func TestInitialize(t *testing.T) {
|
||||
"init call fails": {
|
||||
provider: cloudprovider.GCP,
|
||||
configMutator: func(c *config.Config) { c.Provider.GCP.ServiceAccountKeyPath = serviceAccPath },
|
||||
stateFile: &state.State{Version: state.Version1, Infrastructure: state.Infrastructure{ClusterEndpoint: "192.0.2.1"}},
|
||||
stateFile: preInitStateFile(),
|
||||
serviceAccKey: gcpServiceAccKey,
|
||||
initServerAPI: &stubInitServer{initErr: assert.AnError},
|
||||
retriable: false,
|
||||
@ -193,7 +194,7 @@ func TestInitialize(t *testing.T) {
|
||||
},
|
||||
"k8s version without v works": {
|
||||
provider: cloudprovider.Azure,
|
||||
stateFile: &state.State{Version: state.Version1, Infrastructure: state.Infrastructure{ClusterEndpoint: "192.0.2.1"}},
|
||||
stateFile: preInitStateFile(),
|
||||
initServerAPI: &stubInitServer{res: []*initproto.InitResponse{{Kind: &initproto.InitResponse_InitSuccess{InitSuccess: testInitResp}}}},
|
||||
configMutator: func(c *config.Config) {
|
||||
res, err := versions.NewValidK8sVersion(strings.TrimPrefix(string(versions.Default), "v"), true)
|
||||
@ -203,7 +204,7 @@ func TestInitialize(t *testing.T) {
|
||||
},
|
||||
"outdated k8s patch version doesn't work": {
|
||||
provider: cloudprovider.Azure,
|
||||
stateFile: &state.State{Version: state.Version1, Infrastructure: state.Infrastructure{ClusterEndpoint: "192.0.2.1"}},
|
||||
stateFile: preInitStateFile(),
|
||||
initServerAPI: &stubInitServer{res: []*initproto.InitResponse{{Kind: &initproto.InitResponse_InitSuccess{InitSuccess: testInitResp}}}},
|
||||
configMutator: func(c *config.Config) {
|
||||
v, err := semver.New(versions.SupportedK8sVersions()[0])
|
||||
|
@ -119,6 +119,10 @@ func (r *recoverCmd) recover(
|
||||
if err != nil {
|
||||
return fmt.Errorf("reading state file: %w", err)
|
||||
}
|
||||
if err := stateFile.Validate(state.PostInit, provider); err != nil {
|
||||
return fmt.Errorf("validating state file: %w", err)
|
||||
}
|
||||
|
||||
endpoint, err := r.parseEndpoint(stateFile)
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -15,7 +15,6 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/edgelesssys/constellation/v2/cli/internal/state"
|
||||
"github.com/edgelesssys/constellation/v2/disk-mapper/recoverproto"
|
||||
"github.com/edgelesssys/constellation/v2/internal/atls"
|
||||
"github.com/edgelesssys/constellation/v2/internal/cloud/cloudprovider"
|
||||
@ -159,7 +158,7 @@ func TestRecover(t *testing.T) {
|
||||
))
|
||||
require.NoError(fileHandler.WriteYAML(
|
||||
constants.StateFilename,
|
||||
state.New(),
|
||||
defaultGCPStateFile(),
|
||||
file.OptNone,
|
||||
))
|
||||
|
||||
|
@ -33,18 +33,10 @@ import (
|
||||
)
|
||||
|
||||
func TestUpgradeApply(t *testing.T) {
|
||||
defaultState := state.New().
|
||||
SetInfrastructure(state.Infrastructure{
|
||||
APIServerCertSANs: []string{},
|
||||
UID: "uid",
|
||||
Name: "kubernetes-uid", // default test cfg uses "kubernetes" prefix
|
||||
InitSecret: []byte{0x42},
|
||||
}).
|
||||
SetClusterValues(state.ClusterValues{MeasurementSalt: []byte{0x41}})
|
||||
fsWithStateFileAndTfState := func() file.Handler {
|
||||
fh := file.NewHandler(afero.NewMemMapFs())
|
||||
require.NoError(t, fh.MkdirAll(constants.TerraformWorkingDir))
|
||||
require.NoError(t, fh.WriteYAML(constants.StateFilename, defaultState))
|
||||
require.NoError(t, fh.WriteYAML(constants.StateFilename, defaultAzureStateFile()))
|
||||
return fh
|
||||
}
|
||||
|
||||
@ -63,20 +55,20 @@ func TestUpgradeApply(t *testing.T) {
|
||||
kubeUpgrader: &stubKubernetesUpgrader{currentConfig: config.DefaultForAzureSEVSNP()},
|
||||
helmUpgrader: stubApplier{},
|
||||
terraformUpgrader: &stubTerraformUpgrader{},
|
||||
flags: applyFlags{yes: true},
|
||||
flags: applyFlags{yes: true, skipPhases: skipPhases{skipInitPhase: struct{}{}}},
|
||||
fh: fsWithStateFileAndTfState,
|
||||
fhAssertions: func(require *require.Assertions, assert *assert.Assertions, fh file.Handler) {
|
||||
gotState, err := state.ReadFromFile(fh, constants.StateFilename)
|
||||
require.NoError(err)
|
||||
assert.Equal("v1", gotState.Version)
|
||||
assert.Equal(defaultState, gotState)
|
||||
assert.Equal(defaultAzureStateFile(), gotState)
|
||||
},
|
||||
},
|
||||
"id file and state file do not exist": {
|
||||
kubeUpgrader: &stubKubernetesUpgrader{currentConfig: config.DefaultForAzureSEVSNP()},
|
||||
helmUpgrader: stubApplier{},
|
||||
terraformUpgrader: &stubTerraformUpgrader{},
|
||||
flags: applyFlags{yes: true},
|
||||
flags: applyFlags{yes: true, skipPhases: skipPhases{skipInitPhase: struct{}{}}},
|
||||
fh: func() file.Handler {
|
||||
return file.NewHandler(afero.NewMemMapFs())
|
||||
},
|
||||
@ -90,7 +82,7 @@ func TestUpgradeApply(t *testing.T) {
|
||||
helmUpgrader: stubApplier{},
|
||||
terraformUpgrader: &stubTerraformUpgrader{},
|
||||
wantErr: true,
|
||||
flags: applyFlags{yes: true},
|
||||
flags: applyFlags{yes: true, skipPhases: skipPhases{skipInitPhase: struct{}{}}},
|
||||
fh: fsWithStateFileAndTfState,
|
||||
},
|
||||
"nodeVersion in progress error": {
|
||||
@ -100,7 +92,7 @@ func TestUpgradeApply(t *testing.T) {
|
||||
},
|
||||
helmUpgrader: stubApplier{},
|
||||
terraformUpgrader: &stubTerraformUpgrader{},
|
||||
flags: applyFlags{yes: true},
|
||||
flags: applyFlags{yes: true, skipPhases: skipPhases{skipInitPhase: struct{}{}}},
|
||||
fh: fsWithStateFileAndTfState,
|
||||
},
|
||||
"helm other error": {
|
||||
@ -110,7 +102,7 @@ func TestUpgradeApply(t *testing.T) {
|
||||
helmUpgrader: stubApplier{err: assert.AnError},
|
||||
terraformUpgrader: &stubTerraformUpgrader{},
|
||||
wantErr: true,
|
||||
flags: applyFlags{yes: true},
|
||||
flags: applyFlags{yes: true, skipPhases: skipPhases{skipInitPhase: struct{}{}}},
|
||||
fh: fsWithStateFileAndTfState,
|
||||
},
|
||||
"abort": {
|
||||
@ -140,7 +132,7 @@ func TestUpgradeApply(t *testing.T) {
|
||||
helmUpgrader: stubApplier{},
|
||||
terraformUpgrader: &stubTerraformUpgrader{planTerraformErr: assert.AnError},
|
||||
wantErr: true,
|
||||
flags: applyFlags{yes: true},
|
||||
flags: applyFlags{yes: true, skipPhases: skipPhases{skipInitPhase: struct{}{}}},
|
||||
fh: fsWithStateFileAndTfState,
|
||||
},
|
||||
"apply terraform error": {
|
||||
@ -153,7 +145,7 @@ func TestUpgradeApply(t *testing.T) {
|
||||
terraformDiff: true,
|
||||
},
|
||||
wantErr: true,
|
||||
flags: applyFlags{yes: true},
|
||||
flags: applyFlags{yes: true, skipPhases: skipPhases{skipInitPhase: struct{}{}}},
|
||||
fh: fsWithStateFileAndTfState,
|
||||
},
|
||||
"outdated K8s patch version": {
|
||||
@ -167,7 +159,7 @@ func TestUpgradeApply(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
return semver.NewFromInt(v.Major(), v.Minor(), v.Patch()-1, "").String()
|
||||
}(),
|
||||
flags: applyFlags{yes: true},
|
||||
flags: applyFlags{yes: true, skipPhases: skipPhases{skipInitPhase: struct{}{}}},
|
||||
fh: fsWithStateFileAndTfState,
|
||||
},
|
||||
"outdated K8s version": {
|
||||
@ -177,7 +169,7 @@ func TestUpgradeApply(t *testing.T) {
|
||||
helmUpgrader: stubApplier{},
|
||||
terraformUpgrader: &stubTerraformUpgrader{},
|
||||
customK8sVersion: "v1.20.0",
|
||||
flags: applyFlags{yes: true},
|
||||
flags: applyFlags{yes: true, skipPhases: skipPhases{skipInitPhase: struct{}{}}},
|
||||
wantErr: true,
|
||||
fh: fsWithStateFileAndTfState,
|
||||
},
|
||||
@ -191,6 +183,7 @@ func TestUpgradeApply(t *testing.T) {
|
||||
skipPhases: skipPhases{
|
||||
skipInfrastructurePhase: struct{}{}, skipHelmPhase: struct{}{},
|
||||
skipK8sPhase: struct{}{}, skipImagePhase: struct{}{},
|
||||
skipInitPhase: struct{}{},
|
||||
},
|
||||
yes: true,
|
||||
},
|
||||
@ -205,7 +198,7 @@ func TestUpgradeApply(t *testing.T) {
|
||||
flags: applyFlags{
|
||||
skipPhases: skipPhases{
|
||||
skipInfrastructurePhase: struct{}{}, skipHelmPhase: struct{}{},
|
||||
skipK8sPhase: struct{}{},
|
||||
skipK8sPhase: struct{}{}, skipInitPhase: struct{}{},
|
||||
},
|
||||
yes: true,
|
||||
},
|
||||
@ -219,10 +212,13 @@ func TestUpgradeApply(t *testing.T) {
|
||||
terraformUpgrader: &mockTerraformUpgrader{},
|
||||
flags: applyFlags{
|
||||
yes: true,
|
||||
skipPhases: skipPhases{
|
||||
skipInitPhase: struct{}{},
|
||||
},
|
||||
},
|
||||
fh: func() file.Handler {
|
||||
fh := file.NewHandler(afero.NewMemMapFs())
|
||||
require.NoError(t, fh.WriteYAML(constants.StateFilename, defaultState))
|
||||
require.NoError(t, fh.WriteYAML(constants.StateFilename, defaultAzureStateFile()))
|
||||
return fh
|
||||
},
|
||||
},
|
||||
@ -230,7 +226,7 @@ func TestUpgradeApply(t *testing.T) {
|
||||
kubeUpgrader: &stubKubernetesUpgrader{currentConfig: &config.AzureTrustedLaunch{}},
|
||||
helmUpgrader: stubApplier{},
|
||||
terraformUpgrader: &stubTerraformUpgrader{},
|
||||
flags: applyFlags{yes: true},
|
||||
flags: applyFlags{yes: true, skipPhases: skipPhases{skipInitPhase: struct{}{}}},
|
||||
fh: fsWithStateFileAndTfState,
|
||||
wantErr: true,
|
||||
},
|
||||
|
@ -155,6 +155,9 @@ func (c *verifyCmd) verify(cmd *cobra.Command, verifyClient verifyClient, factor
|
||||
if err != nil {
|
||||
return fmt.Errorf("reading state file: %w", err)
|
||||
}
|
||||
if err := stateFile.Validate(state.PostInit, conf.GetProvider()); err != nil {
|
||||
return fmt.Errorf("validating state file: %w", err)
|
||||
}
|
||||
|
||||
ownerID, clusterID, err := c.validateIDFlags(cmd, stateFile)
|
||||
if err != nil {
|
||||
|
@ -48,7 +48,7 @@ func TestVerify(t *testing.T) {
|
||||
formatter *stubAttDocFormatter
|
||||
nodeEndpointFlag string
|
||||
clusterIDFlag string
|
||||
stateFile *state.State
|
||||
stateFile func() *state.State
|
||||
wantEndpoint string
|
||||
skipConfigCreation bool
|
||||
wantErr bool
|
||||
@ -58,7 +58,7 @@ func TestVerify(t *testing.T) {
|
||||
nodeEndpointFlag: "192.0.2.1:1234",
|
||||
clusterIDFlag: zeroBase64,
|
||||
protoClient: &stubVerifyClient{},
|
||||
stateFile: state.New(),
|
||||
stateFile: defaultGCPStateFile,
|
||||
wantEndpoint: "192.0.2.1:1234",
|
||||
formatter: &stubAttDocFormatter{},
|
||||
},
|
||||
@ -67,7 +67,7 @@ func TestVerify(t *testing.T) {
|
||||
nodeEndpointFlag: "192.0.2.1:1234",
|
||||
clusterIDFlag: zeroBase64,
|
||||
protoClient: &stubVerifyClient{},
|
||||
stateFile: state.New(),
|
||||
stateFile: defaultAzureStateFile,
|
||||
wantEndpoint: "192.0.2.1:1234",
|
||||
formatter: &stubAttDocFormatter{},
|
||||
},
|
||||
@ -76,7 +76,7 @@ func TestVerify(t *testing.T) {
|
||||
nodeEndpointFlag: "192.0.2.1",
|
||||
clusterIDFlag: zeroBase64,
|
||||
protoClient: &stubVerifyClient{},
|
||||
stateFile: state.New(),
|
||||
stateFile: defaultGCPStateFile,
|
||||
wantEndpoint: "192.0.2.1:" + strconv.Itoa(constants.VerifyServiceNodePortGRPC),
|
||||
formatter: &stubAttDocFormatter{},
|
||||
},
|
||||
@ -84,56 +84,78 @@ func TestVerify(t *testing.T) {
|
||||
provider: cloudprovider.GCP,
|
||||
clusterIDFlag: zeroBase64,
|
||||
protoClient: &stubVerifyClient{},
|
||||
stateFile: state.New(),
|
||||
formatter: &stubAttDocFormatter{},
|
||||
wantErr: true,
|
||||
stateFile: func() *state.State {
|
||||
s := defaultGCPStateFile()
|
||||
s.Infrastructure.ClusterEndpoint = ""
|
||||
return s
|
||||
},
|
||||
formatter: &stubAttDocFormatter{},
|
||||
wantErr: true,
|
||||
},
|
||||
"endpoint from state file": {
|
||||
provider: cloudprovider.GCP,
|
||||
clusterIDFlag: zeroBase64,
|
||||
protoClient: &stubVerifyClient{},
|
||||
stateFile: &state.State{Infrastructure: state.Infrastructure{ClusterEndpoint: "192.0.2.1"}},
|
||||
wantEndpoint: "192.0.2.1:" + strconv.Itoa(constants.VerifyServiceNodePortGRPC),
|
||||
formatter: &stubAttDocFormatter{},
|
||||
stateFile: func() *state.State {
|
||||
s := defaultGCPStateFile()
|
||||
s.Infrastructure.ClusterEndpoint = "192.0.2.1"
|
||||
return s
|
||||
},
|
||||
wantEndpoint: "192.0.2.1:" + strconv.Itoa(constants.VerifyServiceNodePortGRPC),
|
||||
formatter: &stubAttDocFormatter{},
|
||||
},
|
||||
"override endpoint from details file": {
|
||||
provider: cloudprovider.GCP,
|
||||
nodeEndpointFlag: "192.0.2.2:1234",
|
||||
clusterIDFlag: zeroBase64,
|
||||
protoClient: &stubVerifyClient{},
|
||||
stateFile: &state.State{Infrastructure: state.Infrastructure{ClusterEndpoint: "192.0.2.1"}},
|
||||
wantEndpoint: "192.0.2.2:1234",
|
||||
formatter: &stubAttDocFormatter{},
|
||||
stateFile: func() *state.State {
|
||||
s := defaultGCPStateFile()
|
||||
s.Infrastructure.ClusterEndpoint = "192.0.2.1"
|
||||
return s
|
||||
},
|
||||
wantEndpoint: "192.0.2.2:1234",
|
||||
formatter: &stubAttDocFormatter{},
|
||||
},
|
||||
"invalid endpoint": {
|
||||
provider: cloudprovider.GCP,
|
||||
nodeEndpointFlag: ":::::",
|
||||
clusterIDFlag: zeroBase64,
|
||||
protoClient: &stubVerifyClient{},
|
||||
stateFile: state.New(),
|
||||
stateFile: defaultGCPStateFile,
|
||||
formatter: &stubAttDocFormatter{},
|
||||
wantErr: true,
|
||||
},
|
||||
"neither owner id nor cluster id set": {
|
||||
provider: cloudprovider.GCP,
|
||||
nodeEndpointFlag: "192.0.2.1:1234",
|
||||
stateFile: state.New(),
|
||||
formatter: &stubAttDocFormatter{},
|
||||
wantErr: true,
|
||||
stateFile: func() *state.State {
|
||||
s := defaultGCPStateFile()
|
||||
s.ClusterValues.OwnerID = ""
|
||||
s.ClusterValues.ClusterID = ""
|
||||
return s
|
||||
},
|
||||
formatter: &stubAttDocFormatter{},
|
||||
protoClient: &stubVerifyClient{},
|
||||
wantErr: true,
|
||||
},
|
||||
"use owner id from state file": {
|
||||
provider: cloudprovider.GCP,
|
||||
nodeEndpointFlag: "192.0.2.1:1234",
|
||||
protoClient: &stubVerifyClient{},
|
||||
stateFile: &state.State{ClusterValues: state.ClusterValues{OwnerID: zeroBase64}},
|
||||
wantEndpoint: "192.0.2.1:1234",
|
||||
formatter: &stubAttDocFormatter{},
|
||||
stateFile: func() *state.State {
|
||||
s := defaultGCPStateFile()
|
||||
s.ClusterValues.OwnerID = zeroBase64
|
||||
return s
|
||||
},
|
||||
wantEndpoint: "192.0.2.1:1234",
|
||||
formatter: &stubAttDocFormatter{},
|
||||
},
|
||||
"config file not existing": {
|
||||
provider: cloudprovider.GCP,
|
||||
clusterIDFlag: zeroBase64,
|
||||
nodeEndpointFlag: "192.0.2.1:1234",
|
||||
stateFile: state.New(),
|
||||
stateFile: defaultGCPStateFile,
|
||||
formatter: &stubAttDocFormatter{},
|
||||
skipConfigCreation: true,
|
||||
wantErr: true,
|
||||
@ -143,7 +165,7 @@ func TestVerify(t *testing.T) {
|
||||
nodeEndpointFlag: "192.0.2.1:1234",
|
||||
clusterIDFlag: zeroBase64,
|
||||
protoClient: &stubVerifyClient{verifyErr: rpcStatus.Error(codes.Internal, "failed")},
|
||||
stateFile: state.New(),
|
||||
stateFile: defaultAzureStateFile,
|
||||
formatter: &stubAttDocFormatter{},
|
||||
wantErr: true,
|
||||
},
|
||||
@ -152,7 +174,7 @@ func TestVerify(t *testing.T) {
|
||||
nodeEndpointFlag: "192.0.2.1:1234",
|
||||
clusterIDFlag: zeroBase64,
|
||||
protoClient: &stubVerifyClient{verifyErr: someErr},
|
||||
stateFile: state.New(),
|
||||
stateFile: defaultAzureStateFile,
|
||||
formatter: &stubAttDocFormatter{},
|
||||
wantErr: true,
|
||||
},
|
||||
@ -161,7 +183,7 @@ func TestVerify(t *testing.T) {
|
||||
nodeEndpointFlag: "192.0.2.1:1234",
|
||||
clusterIDFlag: zeroBase64,
|
||||
protoClient: &stubVerifyClient{},
|
||||
stateFile: state.New(),
|
||||
stateFile: defaultAzureStateFile,
|
||||
wantEndpoint: "192.0.2.1:1234",
|
||||
formatter: &stubAttDocFormatter{formatErr: someErr},
|
||||
wantErr: true,
|
||||
@ -182,7 +204,7 @@ func TestVerify(t *testing.T) {
|
||||
cfg := defaultConfigWithExpectedMeasurements(t, config.Default(), tc.provider)
|
||||
require.NoError(fileHandler.WriteYAML(constants.ConfigFilename, cfg))
|
||||
}
|
||||
require.NoError(tc.stateFile.WriteToFile(fileHandler, constants.StateFilename))
|
||||
require.NoError(tc.stateFile().WriteToFile(fileHandler, constants.StateFilename))
|
||||
|
||||
v := &verifyCmd{
|
||||
fileHandler: fileHandler,
|
||||
|
@ -10,7 +10,9 @@ go_library(
|
||||
importpath = "github.com/edgelesssys/constellation/v2/cli/internal/state",
|
||||
visibility = ["//cli:__subpackages__"],
|
||||
deps = [
|
||||
"//internal/cloud/cloudprovider",
|
||||
"//internal/file",
|
||||
"//internal/validation",
|
||||
"@cat_dario_mergo//:mergo",
|
||||
"@com_github_siderolabs_talos_pkg_machinery//config/encoder",
|
||||
],
|
||||
@ -18,9 +20,13 @@ go_library(
|
||||
|
||||
go_test(
|
||||
name = "state_test",
|
||||
srcs = ["state_test.go"],
|
||||
srcs = [
|
||||
"state_test.go",
|
||||
"validation_test.go",
|
||||
],
|
||||
embed = [":state"],
|
||||
deps = [
|
||||
"//internal/cloud/cloudprovider",
|
||||
"//internal/constants",
|
||||
"//internal/file",
|
||||
"@com_github_siderolabs_talos_pkg_machinery//config/encoder",
|
||||
|
@ -19,7 +19,9 @@ import (
|
||||
"os"
|
||||
|
||||
"dario.cat/mergo"
|
||||
"github.com/edgelesssys/constellation/v2/internal/cloud/cloudprovider"
|
||||
"github.com/edgelesssys/constellation/v2/internal/file"
|
||||
"github.com/edgelesssys/constellation/v2/internal/validation"
|
||||
)
|
||||
|
||||
const (
|
||||
@ -27,25 +29,44 @@ const (
|
||||
Version1 = "v1"
|
||||
)
|
||||
|
||||
// ReadFromFile reads the state file at the given path and returns the state.
|
||||
const (
|
||||
// PreCreate are the constraints that should be enforced when the state file
|
||||
// is validated before cloud infrastructure is created.
|
||||
PreCreate ConstraintSet = iota
|
||||
// PreInit are the constraints that should be enforced when the state file
|
||||
// is validated before the first Constellation node is initialized.
|
||||
PreInit
|
||||
// PostInit are the constraints that should be enforced when the state file
|
||||
// is validated after the cluster was initialized.
|
||||
PostInit
|
||||
)
|
||||
|
||||
// ConstraintSet defines which constraints the state file
|
||||
// should be validated against.
|
||||
type ConstraintSet int
|
||||
|
||||
// ReadFromFile reads the state file at the given path and validates it.
|
||||
// If the state file is valid, the state is returned. Otherwise, an error
|
||||
// describing why the validation failed is returned.
|
||||
func ReadFromFile(fileHandler file.Handler, path string) (*State, error) {
|
||||
state := &State{}
|
||||
if err := fileHandler.ReadYAML(path, &state); err != nil {
|
||||
return nil, fmt.Errorf("reading state file: %w", err)
|
||||
}
|
||||
|
||||
return state, nil
|
||||
}
|
||||
|
||||
// CreateOrRead reads the state file at the given path, if it exists, and returns the state.
|
||||
// If the file does not exist, a new state is created and written to disk.
|
||||
func CreateOrRead(fileHandler file.Handler, path string) (*State, error) {
|
||||
state := &State{}
|
||||
if err := fileHandler.ReadYAML(path, &state); err != nil {
|
||||
state, err := ReadFromFile(fileHandler, path)
|
||||
if err != nil {
|
||||
if !errors.Is(err, os.ErrNotExist) {
|
||||
return nil, fmt.Errorf("reading state file: %w", err)
|
||||
}
|
||||
state = New()
|
||||
return state, state.WriteToFile(fileHandler, path)
|
||||
newState := New()
|
||||
return newState, newState.WriteToFile(fileHandler, path)
|
||||
}
|
||||
return state, nil
|
||||
}
|
||||
@ -186,6 +207,349 @@ func (s *State) Merge(other *State) (*State, error) {
|
||||
return s, nil
|
||||
}
|
||||
|
||||
/*
|
||||
Validate validates the state against the given constraint set and CSP, which can be one of
|
||||
- PreCreate, which is the constraint set that should be enforced before "constellation create" is run.
|
||||
- PreInit, which is the constraint set that should be enforced before "constellation apply" is run.
|
||||
- PostInit, which is the constraint set that should be enforced after "constellation apply" is run.
|
||||
*/
|
||||
func (s *State) Validate(constraintSet ConstraintSet, csp cloudprovider.Provider) error {
|
||||
v := validation.NewValidator()
|
||||
|
||||
switch constraintSet {
|
||||
case PreCreate:
|
||||
return v.Validate(s, validation.ValidateOptions{
|
||||
OverrideConstraints: s.preCreateConstraints,
|
||||
})
|
||||
case PreInit:
|
||||
return v.Validate(s, validation.ValidateOptions{
|
||||
OverrideConstraints: s.preInitConstraints,
|
||||
})
|
||||
case PostInit:
|
||||
return v.Validate(s, validation.ValidateOptions{
|
||||
OverrideConstraints: s.postInitConstraints(csp),
|
||||
})
|
||||
default:
|
||||
return errors.New("unknown constraint set")
|
||||
}
|
||||
}
|
||||
|
||||
// preCreateConstraints are the constraints on the state that should be enforced
|
||||
// before a Constellation cluster is created.
|
||||
//
|
||||
// The constraints check if the state file version is valid,
|
||||
// and if all fields are empty, which is a requirement pre-create.
|
||||
func (s *State) preCreateConstraints() []*validation.Constraint {
|
||||
return []*validation.Constraint{
|
||||
// state version needs to be accepted by the parsing CLI.
|
||||
validation.OneOf(s.Version, []string{Version1}).
|
||||
WithFieldTrace(s, &s.Version),
|
||||
// Infrastructure must be empty.
|
||||
// As the infrastructure struct contains slices, we cannot use the
|
||||
// Empty constraint on the entire struct. Instead, we need to check
|
||||
// each field individually.
|
||||
validation.Empty(s.Infrastructure.UID).
|
||||
WithFieldTrace(s, &s.Infrastructure.UID),
|
||||
validation.Empty(s.Infrastructure.ClusterEndpoint).
|
||||
WithFieldTrace(s, &s.Infrastructure.ClusterEndpoint),
|
||||
validation.Empty(s.Infrastructure.InClusterEndpoint).
|
||||
WithFieldTrace(s, &s.Infrastructure.InClusterEndpoint),
|
||||
validation.Empty(s.Infrastructure.Name).
|
||||
WithFieldTrace(s, &s.Infrastructure.Name),
|
||||
validation.Empty(s.Infrastructure.IPCidrNode).
|
||||
WithFieldTrace(s, &s.Infrastructure.IPCidrNode),
|
||||
validation.EmptySlice(s.Infrastructure.APIServerCertSANs).
|
||||
WithFieldTrace(s, &s.Infrastructure.APIServerCertSANs),
|
||||
validation.EmptySlice(s.Infrastructure.InitSecret).
|
||||
WithFieldTrace(s, &s.Infrastructure.InitSecret),
|
||||
// ClusterValues must be empty.
|
||||
// As the clusterValues struct contains slices, we cannot use the
|
||||
// Empty constraint on the entire struct. Instead, we need to check
|
||||
// each field individually.
|
||||
validation.Empty(s.ClusterValues.ClusterID).
|
||||
WithFieldTrace(s, &s.ClusterValues.ClusterID),
|
||||
validation.Empty(s.ClusterValues.OwnerID).
|
||||
WithFieldTrace(s, &s.ClusterValues.OwnerID),
|
||||
validation.EmptySlice(s.ClusterValues.MeasurementSalt).
|
||||
WithFieldTrace(s, &s.ClusterValues.MeasurementSalt),
|
||||
}
|
||||
}
|
||||
|
||||
// preInitConstraints are the constraints on the state that should be enforced
|
||||
// *before* a Constellation cluster is initialized. (i.e. before "constellation apply" is run.)
|
||||
//
|
||||
// The constraints check if the infrastructure state is valid, and if the cluster values
|
||||
// are empty, which is required for the cluster to initialize correctly.
|
||||
func (s *State) preInitConstraints() []*validation.Constraint {
|
||||
return []*validation.Constraint{
|
||||
// state version needs to be accepted by the parsing CLI.
|
||||
validation.OneOf(s.Version, []string{Version1}).
|
||||
WithFieldTrace(s, &s.Version),
|
||||
// infrastructure must be valid.
|
||||
// out-of-cluster endpoint needs to be a valid DNS name or IP address.
|
||||
validation.Or(
|
||||
validation.DNSName(s.Infrastructure.ClusterEndpoint).
|
||||
WithFieldTrace(s, &s.Infrastructure.ClusterEndpoint),
|
||||
validation.IPAddress(s.Infrastructure.ClusterEndpoint).
|
||||
WithFieldTrace(s, &s.Infrastructure.ClusterEndpoint),
|
||||
),
|
||||
// in-cluster endpoint needs to be a valid DNS name or IP address.
|
||||
validation.Or(
|
||||
validation.DNSName(s.Infrastructure.InClusterEndpoint).
|
||||
WithFieldTrace(s, &s.Infrastructure.InClusterEndpoint),
|
||||
validation.IPAddress(s.Infrastructure.InClusterEndpoint).
|
||||
WithFieldTrace(s, &s.Infrastructure.InClusterEndpoint),
|
||||
),
|
||||
// Node IP Cidr needs to be a valid CIDR range.
|
||||
validation.CIDR(s.Infrastructure.IPCidrNode).
|
||||
WithFieldTrace(s, &s.Infrastructure.IPCidrNode),
|
||||
// UID needs to be filled.
|
||||
validation.NotEmpty(s.Infrastructure.UID).
|
||||
WithFieldTrace(s, &s.Infrastructure.UID),
|
||||
// Name needs to be filled.
|
||||
validation.NotEmpty(s.Infrastructure.Name).
|
||||
WithFieldTrace(s, &s.Infrastructure.Name),
|
||||
// GCP values need to be nil, empty, or valid.
|
||||
validation.Or(
|
||||
validation.Or(
|
||||
// nil.
|
||||
validation.Equal(s.Infrastructure.GCP, nil).
|
||||
WithFieldTrace(s, &s.Infrastructure.GCP),
|
||||
// empty.
|
||||
validation.IfNotNil(
|
||||
s.Infrastructure.GCP,
|
||||
func() *validation.Constraint {
|
||||
return validation.Empty(*s.Infrastructure.GCP).
|
||||
WithFieldTrace(s, &s.Infrastructure.GCP)
|
||||
},
|
||||
),
|
||||
),
|
||||
// valid.
|
||||
validation.IfNotNil(
|
||||
s.Infrastructure.GCP,
|
||||
func() *validation.Constraint {
|
||||
return validation.And(
|
||||
validation.EvaluateAll,
|
||||
// ProjectID needs to be filled.
|
||||
validation.NotEmpty(s.Infrastructure.GCP.ProjectID).
|
||||
WithFieldTrace(s, &s.Infrastructure.GCP.ProjectID),
|
||||
// Pod IP Cidr needs to be a valid CIDR range.
|
||||
validation.CIDR(s.Infrastructure.GCP.IPCidrPod).
|
||||
WithFieldTrace(s, &s.Infrastructure.GCP.IPCidrPod),
|
||||
)
|
||||
},
|
||||
),
|
||||
),
|
||||
// Azure values need to be nil, empty, or valid.
|
||||
validation.Or(
|
||||
validation.Or(
|
||||
// nil.
|
||||
validation.Equal(s.Infrastructure.Azure, nil).
|
||||
WithFieldTrace(s, &s.Infrastructure.Azure),
|
||||
// empty.
|
||||
validation.IfNotNil(
|
||||
s.Infrastructure.Azure,
|
||||
func() *validation.Constraint {
|
||||
return validation.And(
|
||||
validation.EvaluateAll,
|
||||
validation.Empty(s.Infrastructure.Azure.ResourceGroup).
|
||||
WithFieldTrace(s, &s.Infrastructure.Azure.ResourceGroup),
|
||||
validation.Empty(s.Infrastructure.Azure.SubscriptionID).
|
||||
WithFieldTrace(s, &s.Infrastructure.Azure.SubscriptionID),
|
||||
validation.Empty(s.Infrastructure.Azure.NetworkSecurityGroupName).
|
||||
WithFieldTrace(s, &s.Infrastructure.Azure.NetworkSecurityGroupName),
|
||||
validation.Empty(s.Infrastructure.Azure.LoadBalancerName).
|
||||
WithFieldTrace(s, &s.Infrastructure.Azure.LoadBalancerName),
|
||||
validation.Empty(s.Infrastructure.Azure.UserAssignedIdentity).
|
||||
WithFieldTrace(s, &s.Infrastructure.Azure.UserAssignedIdentity),
|
||||
validation.Empty(s.Infrastructure.Azure.AttestationURL).
|
||||
WithFieldTrace(s, &s.Infrastructure.Azure.AttestationURL),
|
||||
)
|
||||
},
|
||||
),
|
||||
),
|
||||
// valid.
|
||||
validation.IfNotNil(
|
||||
s.Infrastructure.Azure,
|
||||
func() *validation.Constraint {
|
||||
return validation.And(
|
||||
validation.EvaluateAll,
|
||||
validation.NotEmpty(s.Infrastructure.Azure.ResourceGroup).
|
||||
WithFieldTrace(s, &s.Infrastructure.Azure.ResourceGroup),
|
||||
validation.NotEmpty(s.Infrastructure.Azure.SubscriptionID).
|
||||
WithFieldTrace(s, &s.Infrastructure.Azure.SubscriptionID),
|
||||
validation.NotEmpty(s.Infrastructure.Azure.NetworkSecurityGroupName).
|
||||
WithFieldTrace(s, &s.Infrastructure.Azure.NetworkSecurityGroupName),
|
||||
validation.NotEmpty(s.Infrastructure.Azure.LoadBalancerName).
|
||||
WithFieldTrace(s, &s.Infrastructure.Azure.LoadBalancerName),
|
||||
validation.NotEmpty(s.Infrastructure.Azure.UserAssignedIdentity).
|
||||
WithFieldTrace(s, &s.Infrastructure.Azure.UserAssignedIdentity),
|
||||
validation.NotEmpty(s.Infrastructure.Azure.AttestationURL).
|
||||
WithFieldTrace(s, &s.Infrastructure.Azure.AttestationURL),
|
||||
)
|
||||
},
|
||||
),
|
||||
),
|
||||
// ClusterValues must be empty.
|
||||
// As the clusterValues struct contains slices, we cannot use the
|
||||
// Empty constraint on the entire struct. Instead, we need to check
|
||||
// each field individually.
|
||||
validation.Empty(s.ClusterValues.ClusterID).
|
||||
WithFieldTrace(s, &s.ClusterValues.ClusterID),
|
||||
validation.Empty(s.ClusterValues.OwnerID).
|
||||
WithFieldTrace(s, &s.ClusterValues.OwnerID),
|
||||
validation.EmptySlice(s.ClusterValues.MeasurementSalt).
|
||||
WithFieldTrace(s, &s.ClusterValues.MeasurementSalt),
|
||||
}
|
||||
}
|
||||
|
||||
// postInitConstraints are the constraints on the state that should be enforced
|
||||
// *after* a Constellation cluster is initialized. (i.e. before "constellation apply" is run.)
|
||||
//
|
||||
// The constraints check if the infrastructure state and cluster state
|
||||
// is valid, so that the cluster can be used correctly.
|
||||
func (s *State) postInitConstraints(csp cloudprovider.Provider) func() []*validation.Constraint {
|
||||
return func() []*validation.Constraint {
|
||||
constraints := []*validation.Constraint{
|
||||
// state version needs to be accepted by the parsing CLI.
|
||||
validation.OneOf(s.Version, []string{Version1}).
|
||||
WithFieldTrace(s, &s.Version),
|
||||
// infrastructure must be valid.
|
||||
// out-of-cluster endpoint needs to be a valid DNS name or IP address.
|
||||
validation.Or(
|
||||
validation.DNSName(s.Infrastructure.ClusterEndpoint).
|
||||
WithFieldTrace(s, &s.Infrastructure.ClusterEndpoint),
|
||||
validation.IPAddress(s.Infrastructure.ClusterEndpoint).
|
||||
WithFieldTrace(s, &s.Infrastructure.ClusterEndpoint),
|
||||
),
|
||||
// in-cluster endpoint needs to be a valid DNS name or IP address.
|
||||
validation.Or(
|
||||
validation.DNSName(s.Infrastructure.InClusterEndpoint).
|
||||
WithFieldTrace(s, &s.Infrastructure.InClusterEndpoint),
|
||||
validation.IPAddress(s.Infrastructure.InClusterEndpoint).
|
||||
WithFieldTrace(s, &s.Infrastructure.InClusterEndpoint),
|
||||
),
|
||||
// Node IP Cidr needs to be a valid CIDR range.
|
||||
validation.CIDR(s.Infrastructure.IPCidrNode).
|
||||
WithFieldTrace(s, &s.Infrastructure.IPCidrNode),
|
||||
// UID needs to be filled.
|
||||
validation.NotEmpty(s.Infrastructure.UID).
|
||||
WithFieldTrace(s, &s.Infrastructure.UID),
|
||||
// Name needs to be filled.
|
||||
validation.NotEmpty(s.Infrastructure.Name).
|
||||
WithFieldTrace(s, &s.Infrastructure.Name),
|
||||
// ClusterValues need to be valid.
|
||||
// ClusterID needs to be filled.
|
||||
validation.NotEmpty(s.ClusterValues.ClusterID).
|
||||
WithFieldTrace(s, &s.ClusterValues.ClusterID),
|
||||
// OwnerID needs to be filled.
|
||||
validation.NotEmpty(s.ClusterValues.OwnerID).
|
||||
WithFieldTrace(s, &s.ClusterValues.OwnerID),
|
||||
// MeasurementSalt needs to be filled.
|
||||
validation.NotEmptySlice(s.ClusterValues.MeasurementSalt).
|
||||
WithFieldTrace(s, &s.ClusterValues.MeasurementSalt),
|
||||
}
|
||||
|
||||
switch csp {
|
||||
case cloudprovider.Azure:
|
||||
constraints = append(constraints,
|
||||
// GCP values need to be nil or empty.
|
||||
validation.Or(
|
||||
validation.Equal(s.Infrastructure.GCP, nil).
|
||||
WithFieldTrace(s, &s.Infrastructure.GCP),
|
||||
validation.IfNotNil(
|
||||
s.Infrastructure.GCP,
|
||||
func() *validation.Constraint {
|
||||
return validation.Empty(s.Infrastructure.GCP).
|
||||
WithFieldTrace(s, &s.Infrastructure.GCP)
|
||||
},
|
||||
)),
|
||||
// Azure values need to be valid.
|
||||
validation.IfNotNil(
|
||||
s.Infrastructure.Azure,
|
||||
func() *validation.Constraint {
|
||||
return validation.And(
|
||||
validation.EvaluateAll,
|
||||
validation.NotEmpty(s.Infrastructure.Azure.ResourceGroup).
|
||||
WithFieldTrace(s, &s.Infrastructure.Azure.ResourceGroup),
|
||||
validation.NotEmpty(s.Infrastructure.Azure.SubscriptionID).
|
||||
WithFieldTrace(s, &s.Infrastructure.Azure.SubscriptionID),
|
||||
validation.NotEmpty(s.Infrastructure.Azure.NetworkSecurityGroupName).
|
||||
WithFieldTrace(s, &s.Infrastructure.Azure.NetworkSecurityGroupName),
|
||||
validation.NotEmpty(s.Infrastructure.Azure.LoadBalancerName).
|
||||
WithFieldTrace(s, &s.Infrastructure.Azure.LoadBalancerName),
|
||||
validation.NotEmpty(s.Infrastructure.Azure.UserAssignedIdentity).
|
||||
WithFieldTrace(s, &s.Infrastructure.Azure.UserAssignedIdentity),
|
||||
validation.NotEmpty(s.Infrastructure.Azure.AttestationURL).
|
||||
WithFieldTrace(s, &s.Infrastructure.Azure.AttestationURL),
|
||||
)
|
||||
},
|
||||
),
|
||||
)
|
||||
case cloudprovider.GCP:
|
||||
constraints = append(constraints,
|
||||
// Azure values need to be nil or empty.
|
||||
validation.Or(
|
||||
validation.Equal(s.Infrastructure.Azure, nil).
|
||||
WithFieldTrace(s, &s.Infrastructure.Azure),
|
||||
validation.IfNotNil(
|
||||
s.Infrastructure.Azure,
|
||||
func() *validation.Constraint {
|
||||
return validation.Empty(s.Infrastructure.Azure).
|
||||
WithFieldTrace(s, &s.Infrastructure.Azure)
|
||||
},
|
||||
)),
|
||||
// GCP values need to be valid.
|
||||
validation.IfNotNil(
|
||||
s.Infrastructure.GCP,
|
||||
func() *validation.Constraint {
|
||||
return validation.And(
|
||||
validation.EvaluateAll,
|
||||
// ProjectID needs to be filled.
|
||||
validation.NotEmpty(s.Infrastructure.GCP.ProjectID).
|
||||
WithFieldTrace(s, &s.Infrastructure.GCP.ProjectID),
|
||||
// Pod IP Cidr needs to be a valid CIDR range.
|
||||
validation.CIDR(s.Infrastructure.GCP.IPCidrPod).
|
||||
WithFieldTrace(s, &s.Infrastructure.GCP.IPCidrPod),
|
||||
)
|
||||
},
|
||||
),
|
||||
)
|
||||
default:
|
||||
constraints = append(constraints,
|
||||
// GCP values need to be nil or empty.
|
||||
validation.Or(
|
||||
validation.Equal(s.Infrastructure.GCP, nil).
|
||||
WithFieldTrace(s, &s.Infrastructure.GCP),
|
||||
validation.IfNotNil(
|
||||
s.Infrastructure.GCP,
|
||||
func() *validation.Constraint {
|
||||
return validation.Empty(s.Infrastructure.GCP).
|
||||
WithFieldTrace(s, &s.Infrastructure.GCP)
|
||||
},
|
||||
)),
|
||||
// Azure values need to be nil or empty.
|
||||
validation.Or(
|
||||
validation.Equal(s.Infrastructure.Azure, nil).
|
||||
WithFieldTrace(s, &s.Infrastructure.Azure),
|
||||
validation.IfNotNil(
|
||||
s.Infrastructure.Azure,
|
||||
func() *validation.Constraint {
|
||||
return validation.Empty(s.Infrastructure.Azure).
|
||||
WithFieldTrace(s, &s.Infrastructure.Azure)
|
||||
},
|
||||
)),
|
||||
)
|
||||
}
|
||||
return constraints
|
||||
}
|
||||
}
|
||||
|
||||
// Constraints is a no-op implementation to fulfill the "Validatable" interface.
|
||||
func (s *State) Constraints() []*validation.Constraint {
|
||||
return []*validation.Constraint{}
|
||||
}
|
||||
|
||||
// HexBytes is a byte slice that is marshalled to and from a hex string.
|
||||
type HexBytes []byte
|
||||
|
||||
|
@ -18,18 +18,21 @@ import (
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
// defaultState returns a valid default state for testing.
|
||||
func defaultState() *State {
|
||||
return &State{
|
||||
Version: "v1",
|
||||
Infrastructure: Infrastructure{
|
||||
UID: "123",
|
||||
ClusterEndpoint: "test-cluster-endpoint",
|
||||
InitSecret: []byte{0x41},
|
||||
UID: "123",
|
||||
Name: "test-cluster",
|
||||
ClusterEndpoint: "0.0.0.0",
|
||||
InClusterEndpoint: "0.0.0.0",
|
||||
InitSecret: []byte{0x41},
|
||||
APIServerCertSANs: []string{
|
||||
"api-server-cert-san-test",
|
||||
"api-server-cert-san-test-2",
|
||||
"127.0.0.1",
|
||||
"www.example.com",
|
||||
},
|
||||
IPCidrNode: "test-cidr-node",
|
||||
IPCidrNode: "0.0.0.0/24",
|
||||
Azure: &Azure{
|
||||
ResourceGroup: "test-rg",
|
||||
SubscriptionID: "test-sub",
|
||||
@ -40,7 +43,7 @@ func defaultState() *State {
|
||||
},
|
||||
GCP: &GCP{
|
||||
ProjectID: "test-project",
|
||||
IPCidrPod: "test-cidr-pod",
|
||||
IPCidrPod: "0.0.0.0/24",
|
||||
},
|
||||
},
|
||||
ClusterValues: ClusterValues{
|
||||
@ -51,6 +54,18 @@ func defaultState() *State {
|
||||
}
|
||||
}
|
||||
|
||||
func defaultAzureState() *State {
|
||||
s := defaultState()
|
||||
s.Infrastructure.GCP = nil
|
||||
return s
|
||||
}
|
||||
|
||||
func defaultGCPState() *State {
|
||||
s := defaultState()
|
||||
s.Infrastructure.Azure = nil
|
||||
return s
|
||||
}
|
||||
|
||||
func TestWriteToFile(t *testing.T) {
|
||||
testCases := map[string]struct {
|
||||
state *State
|
||||
|
379
cli/internal/state/validation_test.go
Normal file
379
cli/internal/state/validation_test.go
Normal file
@ -0,0 +1,379 @@
|
||||
/*
|
||||
Copyright (c) Edgeless Systems GmbH
|
||||
|
||||
SPDX-License-Identifier: AGPL-3.0-only
|
||||
*/
|
||||
|
||||
package state
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/edgelesssys/constellation/v2/internal/cloud/cloudprovider"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestPreCreateValidation(t *testing.T) {
|
||||
testCases := map[string]struct {
|
||||
stateFile func() *State
|
||||
wantErr bool
|
||||
errAssertions func(a *assert.Assertions, err error)
|
||||
}{
|
||||
"valid": {
|
||||
stateFile: func() *State {
|
||||
return &State{
|
||||
Version: Version1,
|
||||
}
|
||||
},
|
||||
},
|
||||
"invalid version": {
|
||||
stateFile: func() *State {
|
||||
return &State{
|
||||
Version: "invalid",
|
||||
}
|
||||
},
|
||||
wantErr: true,
|
||||
errAssertions: func(a *assert.Assertions, err error) {
|
||||
a.Contains(err.Error(), "validating State.version: invalid must be one of [v1]")
|
||||
},
|
||||
},
|
||||
"infrastructure not empty": {
|
||||
stateFile: func() *State {
|
||||
return &State{
|
||||
Version: Version1,
|
||||
Infrastructure: Infrastructure{
|
||||
ClusterEndpoint: "test",
|
||||
},
|
||||
}
|
||||
},
|
||||
wantErr: true,
|
||||
errAssertions: func(a *assert.Assertions, err error) {
|
||||
a.Contains(err.Error(), "validating State.infrastructure.clusterEndpoint: test must be empty")
|
||||
},
|
||||
},
|
||||
"cluster values not empty": {
|
||||
stateFile: func() *State {
|
||||
return &State{
|
||||
Version: Version1,
|
||||
ClusterValues: ClusterValues{
|
||||
ClusterID: "test",
|
||||
},
|
||||
}
|
||||
},
|
||||
wantErr: true,
|
||||
errAssertions: func(a *assert.Assertions, err error) {
|
||||
a.Contains(err.Error(), "validating State.clusterValues.clusterID: test must be empty")
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for name, tc := range testCases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
err := tc.stateFile().Validate(PreCreate, cloudprovider.Azure)
|
||||
if tc.wantErr {
|
||||
require.Error(t, err)
|
||||
if tc.errAssertions != nil {
|
||||
tc.errAssertions(assert.New(t), err)
|
||||
}
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPreInitValidation(t *testing.T) {
|
||||
validPreInitState := func() *State {
|
||||
s := defaultState()
|
||||
s.ClusterValues = ClusterValues{}
|
||||
return s
|
||||
}
|
||||
|
||||
testCases := map[string]struct {
|
||||
stateFile func() *State
|
||||
wantErr bool
|
||||
errAssertions func(a *assert.Assertions, err error)
|
||||
}{
|
||||
"valid": {
|
||||
stateFile: validPreInitState,
|
||||
},
|
||||
"invalid version": {
|
||||
stateFile: func() *State {
|
||||
s := validPreInitState()
|
||||
s.Version = "invalid"
|
||||
return s
|
||||
},
|
||||
wantErr: true,
|
||||
errAssertions: func(a *assert.Assertions, err error) {
|
||||
a.Contains(err.Error(), "validating State.version: invalid must be one of [v1]")
|
||||
},
|
||||
},
|
||||
"cluster endpoint invalid": {
|
||||
stateFile: func() *State {
|
||||
s := validPreInitState()
|
||||
s.Infrastructure.ClusterEndpoint = "invalid"
|
||||
return s
|
||||
},
|
||||
wantErr: true,
|
||||
errAssertions: func(a *assert.Assertions, err error) {
|
||||
a.Contains(err.Error(), "validating State.infrastructure.clusterEndpoint: invalid must be a valid DNS name")
|
||||
a.Contains(err.Error(), "validating State.infrastructure.clusterEndpoint: invalid must be a valid IP address")
|
||||
},
|
||||
},
|
||||
"in-cluster endpoint invalid": {
|
||||
stateFile: func() *State {
|
||||
s := validPreInitState()
|
||||
s.Infrastructure.InClusterEndpoint = "invalid"
|
||||
return s
|
||||
},
|
||||
wantErr: true,
|
||||
errAssertions: func(a *assert.Assertions, err error) {
|
||||
a.Contains(err.Error(), "validating State.infrastructure.inClusterEndpoint: invalid must be a valid DNS name")
|
||||
a.Contains(err.Error(), "validating State.infrastructure.inClusterEndpoint: invalid must be a valid IP address")
|
||||
},
|
||||
},
|
||||
"node ip cidr invalid": {
|
||||
stateFile: func() *State {
|
||||
s := validPreInitState()
|
||||
s.Infrastructure.IPCidrNode = "invalid"
|
||||
return s
|
||||
},
|
||||
wantErr: true,
|
||||
errAssertions: func(a *assert.Assertions, err error) {
|
||||
a.Contains(err.Error(), "validating State.infrastructure.ipCidrNode: invalid must be a valid CIDR")
|
||||
},
|
||||
},
|
||||
"uid empty": {
|
||||
stateFile: func() *State {
|
||||
s := validPreInitState()
|
||||
s.Infrastructure.UID = ""
|
||||
return s
|
||||
},
|
||||
wantErr: true,
|
||||
errAssertions: func(a *assert.Assertions, err error) {
|
||||
a.Contains(err.Error(), "validating State.infrastructure.uid: must not be empty")
|
||||
},
|
||||
},
|
||||
"name empty": {
|
||||
stateFile: func() *State {
|
||||
s := validPreInitState()
|
||||
s.Infrastructure.Name = ""
|
||||
return s
|
||||
},
|
||||
wantErr: true,
|
||||
errAssertions: func(a *assert.Assertions, err error) {
|
||||
a.Contains(err.Error(), "validating State.infrastructure.name: must not be empty")
|
||||
},
|
||||
},
|
||||
"gcp empty": {
|
||||
stateFile: func() *State {
|
||||
s := validPreInitState()
|
||||
s.Infrastructure.GCP = &GCP{}
|
||||
return s
|
||||
},
|
||||
},
|
||||
"gcp nil": {
|
||||
stateFile: func() *State {
|
||||
s := validPreInitState()
|
||||
s.Infrastructure.GCP = nil
|
||||
return s
|
||||
},
|
||||
},
|
||||
"gcp invalid": {
|
||||
stateFile: func() *State {
|
||||
s := validPreInitState()
|
||||
s.Infrastructure.GCP.IPCidrPod = "invalid"
|
||||
return s
|
||||
},
|
||||
wantErr: true,
|
||||
errAssertions: func(a *assert.Assertions, err error) {
|
||||
a.Contains(err.Error(), "validating State.infrastructure.gcp.ipCidrPod: invalid must be a valid CIDR")
|
||||
},
|
||||
},
|
||||
"azure empty": {
|
||||
stateFile: func() *State {
|
||||
s := validPreInitState()
|
||||
s.Infrastructure.Azure = &Azure{}
|
||||
return s
|
||||
},
|
||||
},
|
||||
"azure nil": {
|
||||
stateFile: func() *State {
|
||||
s := validPreInitState()
|
||||
s.Infrastructure.Azure = nil
|
||||
return s
|
||||
},
|
||||
},
|
||||
"azure invalid": {
|
||||
stateFile: func() *State {
|
||||
s := validPreInitState()
|
||||
s.Infrastructure.Azure.NetworkSecurityGroupName = ""
|
||||
return s
|
||||
},
|
||||
wantErr: true,
|
||||
errAssertions: func(a *assert.Assertions, err error) {
|
||||
a.Contains(err.Error(), "validating State.infrastructure.azure.networkSecurityGroupName: must not be empty")
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for name, tc := range testCases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
err := tc.stateFile().Validate(PreInit, cloudprovider.Azure)
|
||||
if tc.wantErr {
|
||||
require.Error(t, err)
|
||||
if tc.errAssertions != nil {
|
||||
tc.errAssertions(assert.New(t), err)
|
||||
}
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPostInitValidation(t *testing.T) {
|
||||
testCases := map[string]struct {
|
||||
stateFile func() *State
|
||||
provider cloudprovider.Provider
|
||||
wantErr bool
|
||||
errAssertions func(a *assert.Assertions, err error)
|
||||
}{
|
||||
"valid": {
|
||||
stateFile: defaultGCPState,
|
||||
provider: cloudprovider.GCP,
|
||||
},
|
||||
"invalid version": {
|
||||
stateFile: func() *State {
|
||||
s := defaultState()
|
||||
s.Version = "invalid"
|
||||
return s
|
||||
},
|
||||
wantErr: true,
|
||||
errAssertions: func(a *assert.Assertions, err error) {
|
||||
a.Contains(err.Error(), "validating State.version: invalid must be one of [v1]")
|
||||
},
|
||||
},
|
||||
"cluster endpoint invalid": {
|
||||
stateFile: func() *State {
|
||||
s := defaultState()
|
||||
s.Infrastructure.ClusterEndpoint = "invalid"
|
||||
return s
|
||||
},
|
||||
wantErr: true,
|
||||
errAssertions: func(a *assert.Assertions, err error) {
|
||||
a.Contains(err.Error(), "validating State.infrastructure.clusterEndpoint: invalid must be a valid DNS name")
|
||||
a.Contains(err.Error(), "validating State.infrastructure.clusterEndpoint: invalid must be a valid IP address")
|
||||
},
|
||||
},
|
||||
"in-cluster endpoint invalid": {
|
||||
stateFile: func() *State {
|
||||
s := defaultState()
|
||||
s.Infrastructure.InClusterEndpoint = "invalid"
|
||||
return s
|
||||
},
|
||||
wantErr: true,
|
||||
errAssertions: func(a *assert.Assertions, err error) {
|
||||
a.Contains(err.Error(), "validating State.infrastructure.inClusterEndpoint: invalid must be a valid DNS name")
|
||||
a.Contains(err.Error(), "validating State.infrastructure.inClusterEndpoint: invalid must be a valid IP address")
|
||||
},
|
||||
},
|
||||
"node ip cidr invalid": {
|
||||
stateFile: func() *State {
|
||||
s := defaultState()
|
||||
s.Infrastructure.IPCidrNode = "invalid"
|
||||
return s
|
||||
},
|
||||
wantErr: true,
|
||||
errAssertions: func(a *assert.Assertions, err error) {
|
||||
a.Contains(err.Error(), "validating State.infrastructure.ipCidrNode: invalid must be a valid CIDR")
|
||||
},
|
||||
},
|
||||
"uid empty": {
|
||||
stateFile: func() *State {
|
||||
s := defaultState()
|
||||
s.Infrastructure.UID = ""
|
||||
return s
|
||||
},
|
||||
wantErr: true,
|
||||
errAssertions: func(a *assert.Assertions, err error) {
|
||||
a.Contains(err.Error(), "validating State.infrastructure.uid: must not be empty")
|
||||
},
|
||||
},
|
||||
"name empty": {
|
||||
stateFile: func() *State {
|
||||
s := defaultState()
|
||||
s.Infrastructure.Name = ""
|
||||
return s
|
||||
},
|
||||
wantErr: true,
|
||||
errAssertions: func(a *assert.Assertions, err error) {
|
||||
a.Contains(err.Error(), "validating State.infrastructure.name: must not be empty")
|
||||
},
|
||||
},
|
||||
"gcp valid": {
|
||||
stateFile: func() *State {
|
||||
s := defaultGCPState()
|
||||
return s
|
||||
},
|
||||
provider: cloudprovider.GCP,
|
||||
},
|
||||
"azure valid": {
|
||||
stateFile: func() *State {
|
||||
s := defaultAzureState()
|
||||
return s
|
||||
},
|
||||
provider: cloudprovider.Azure,
|
||||
},
|
||||
"gcp, azure not nil": {
|
||||
stateFile: func() *State {
|
||||
s := defaultState()
|
||||
return s
|
||||
},
|
||||
provider: cloudprovider.GCP,
|
||||
wantErr: true,
|
||||
errAssertions: func(a *assert.Assertions, err error) {
|
||||
a.Contains(err.Error(), "must be equal to <nil>")
|
||||
a.Contains(err.Error(), "must be empty")
|
||||
},
|
||||
},
|
||||
"azure, gcp not nil": {
|
||||
stateFile: func() *State {
|
||||
s := defaultState()
|
||||
return s
|
||||
},
|
||||
provider: cloudprovider.Azure,
|
||||
wantErr: true,
|
||||
errAssertions: func(a *assert.Assertions, err error) {
|
||||
a.Contains(err.Error(), "must be equal to <nil>")
|
||||
a.Contains(err.Error(), "must be empty")
|
||||
},
|
||||
},
|
||||
"cluster values invalid": {
|
||||
stateFile: func() *State {
|
||||
s := defaultState()
|
||||
s.ClusterValues.ClusterID = ""
|
||||
return s
|
||||
},
|
||||
wantErr: true,
|
||||
errAssertions: func(a *assert.Assertions, err error) {
|
||||
a.Contains(err.Error(), "validating State.clusterValues.clusterID: must not be empty")
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for name, tc := range testCases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
err := tc.stateFile().Validate(PostInit, tc.provider)
|
||||
if tc.wantErr {
|
||||
require.Error(t, err)
|
||||
if tc.errAssertions != nil {
|
||||
tc.errAssertions(assert.New(t), err)
|
||||
}
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
@ -15,6 +15,7 @@ go_library(
|
||||
go_test(
|
||||
name = "validation_test",
|
||||
srcs = [
|
||||
"constraints_test.go",
|
||||
"errors_test.go",
|
||||
"validation_test.go",
|
||||
],
|
||||
|
@ -8,6 +8,7 @@ package validation
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"reflect"
|
||||
"regexp"
|
||||
)
|
||||
@ -15,8 +16,10 @@ import (
|
||||
// Constraint is a constraint on a document or a field of a document.
|
||||
type Constraint struct {
|
||||
// Satisfied returns no error if the constraint is satisfied.
|
||||
// Otherwise, it returns the reason why the constraint is not satisfied.
|
||||
Satisfied func() error
|
||||
// Otherwise, it returns the reason why the constraint is not satisfied,
|
||||
// possibly including its child errors, i.e., errors returned by constraints
|
||||
// that are embedded in this constraint.
|
||||
Satisfied func() *TreeError
|
||||
}
|
||||
|
||||
/*
|
||||
@ -36,7 +39,7 @@ Example for a pointer field:
|
||||
Due to Go's addressability limititations regarding maps, if a map field is
|
||||
to be validated, WithMapFieldTrace must be used instead of WithFieldTrace.
|
||||
*/
|
||||
func (c *Constraint) WithFieldTrace(doc any, field any) Constraint {
|
||||
func (c *Constraint) WithFieldTrace(doc any, field any) *Constraint {
|
||||
// we only want to dereference the needle once to dereference the pointer
|
||||
// used to pass it to the function without losing reference to it, as the
|
||||
// needle could be an arbitrarily long chain of pointers. The same
|
||||
@ -69,7 +72,7 @@ Example:
|
||||
|
||||
For non-map fields, WithFieldTrace should be used instead of WithMapFieldTrace.
|
||||
*/
|
||||
func (c *Constraint) WithMapFieldTrace(doc any, field any, mapKey string) Constraint {
|
||||
func (c *Constraint) WithMapFieldTrace(doc any, field any, mapKey string) *Constraint {
|
||||
// we only want to dereference the needle once to dereference the pointer
|
||||
// used to pass it to the function without losing reference to it, as the
|
||||
// needle could be an arbitrarily long chain of pointers. The same
|
||||
@ -91,11 +94,11 @@ func (c *Constraint) WithMapFieldTrace(doc any, field any, mapKey string) Constr
|
||||
}
|
||||
|
||||
// withTrace wraps the constraint's error message with a well-formatted trace.
|
||||
func (c *Constraint) withTrace(docRef, fieldRef referenceableValue) Constraint {
|
||||
return Constraint{
|
||||
Satisfied: func() error {
|
||||
func (c *Constraint) withTrace(docRef, fieldRef referenceableValue) *Constraint {
|
||||
return &Constraint{
|
||||
Satisfied: func() *TreeError {
|
||||
if err := c.Satisfied(); err != nil {
|
||||
return newError(docRef, fieldRef, err)
|
||||
return newTraceError(docRef, fieldRef, err)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
@ -105,49 +108,207 @@ func (c *Constraint) withTrace(docRef, fieldRef referenceableValue) Constraint {
|
||||
// MatchRegex is a constraint that if s matches regex.
|
||||
func MatchRegex(s string, regex string) *Constraint {
|
||||
return &Constraint{
|
||||
Satisfied: func() error {
|
||||
Satisfied: func() *TreeError {
|
||||
if !regexp.MustCompile(regex).MatchString(s) {
|
||||
return fmt.Errorf("%s must match the pattern %s", s, regex)
|
||||
return NewErrorTree(fmt.Errorf("%s must match the pattern %s", s, regex))
|
||||
}
|
||||
return nil
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Equal is a constraint that if s is equal to t.
|
||||
// Equal is a constraint that checks if s is equal to t.
|
||||
func Equal[T comparable](s T, t T) *Constraint {
|
||||
return &Constraint{
|
||||
Satisfied: func() error {
|
||||
Satisfied: func() *TreeError {
|
||||
if s != t {
|
||||
return fmt.Errorf("%v must be equal to %v", s, t)
|
||||
return NewErrorTree(fmt.Errorf("%v must be equal to %v", s, t))
|
||||
}
|
||||
return nil
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// NotEmpty is a constraint that if s is not empty.
|
||||
func NotEmpty[T comparable](s T) *Constraint {
|
||||
// NotEqual is a constraint that checks if s is not equal to t.
|
||||
func NotEqual[T comparable](s T, t T) *Constraint {
|
||||
return &Constraint{
|
||||
Satisfied: func() error {
|
||||
var zero T
|
||||
if s == zero {
|
||||
return fmt.Errorf("%v must not be empty", s)
|
||||
Satisfied: func() *TreeError {
|
||||
if Equal(s, t).Satisfied() == nil {
|
||||
return NewErrorTree(fmt.Errorf("%v must not be equal to %v", s, t))
|
||||
}
|
||||
return nil
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Empty is a constraint that if s is empty.
|
||||
// Empty is a constraint that checks if s is empty.
|
||||
func Empty[T comparable](s T) *Constraint {
|
||||
return &Constraint{
|
||||
Satisfied: func() error {
|
||||
Satisfied: func() *TreeError {
|
||||
var zero T
|
||||
if s != zero {
|
||||
return fmt.Errorf("%v must be empty", s)
|
||||
return NewErrorTree(fmt.Errorf("%v must be empty", s))
|
||||
}
|
||||
return nil
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// NotEmpty is a constraint that checks if s is not empty.
|
||||
func NotEmpty[T comparable](s T) *Constraint {
|
||||
return &Constraint{
|
||||
Satisfied: func() *TreeError {
|
||||
if Empty(s).Satisfied() == nil {
|
||||
return NewErrorTree(fmt.Errorf("must not be empty"))
|
||||
}
|
||||
return nil
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// OneOf is a constraint that s is in the set of values p.
|
||||
func OneOf[T comparable](s T, p []T) *Constraint {
|
||||
return &Constraint{
|
||||
Satisfied: func() *TreeError {
|
||||
for _, v := range p {
|
||||
if s == v {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return NewErrorTree(fmt.Errorf("%v must be one of %v", s, p))
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// IPAddress is a constraint that checks if s is a valid IP address.
|
||||
func IPAddress(s string) *Constraint {
|
||||
return &Constraint{
|
||||
Satisfied: func() *TreeError {
|
||||
if net.ParseIP(s) == nil {
|
||||
return NewErrorTree(fmt.Errorf("%s must be a valid IP address", s))
|
||||
}
|
||||
return nil
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// CIDR is a constraint that checks if s is a valid CIDR.
|
||||
func CIDR(s string) *Constraint {
|
||||
return &Constraint{
|
||||
Satisfied: func() *TreeError {
|
||||
if _, _, err := net.ParseCIDR(s); err != nil {
|
||||
return NewErrorTree(fmt.Errorf("%s must be a valid CIDR", s))
|
||||
}
|
||||
return nil
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// DNSName is a constraint that checks if s is a valid DNS name.
|
||||
func DNSName(s string) *Constraint {
|
||||
return &Constraint{
|
||||
Satisfied: func() *TreeError {
|
||||
if _, err := net.LookupHost(s); err != nil {
|
||||
return NewErrorTree(fmt.Errorf("%s must be a valid DNS name", s))
|
||||
}
|
||||
return nil
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// EmptySlice is a constraint that checks if s is an empty slice.
|
||||
func EmptySlice[T comparable](s []T) *Constraint {
|
||||
return &Constraint{
|
||||
Satisfied: func() *TreeError {
|
||||
if len(s) != 0 {
|
||||
return NewErrorTree(fmt.Errorf("%v must be empty", s))
|
||||
}
|
||||
return nil
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// NotEmptySlice is a constraint that checks if slice s is not empty.
|
||||
func NotEmptySlice[T comparable](s []T) *Constraint {
|
||||
return &Constraint{
|
||||
Satisfied: func() *TreeError {
|
||||
if EmptySlice(s).Satisfied() == nil {
|
||||
return NewErrorTree(fmt.Errorf("must not be empty"))
|
||||
}
|
||||
return nil
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// All is a constraint that checks if all elements of s satisfy the constraint c.
|
||||
// The constraint should be parametric in regards to the index of the element in s,
|
||||
// as well as the element itself.
|
||||
func All[T comparable](s []T, c func(i int, v T) *Constraint) *Constraint {
|
||||
return &Constraint{
|
||||
Satisfied: func() *TreeError {
|
||||
retErr := NewErrorTree(fmt.Errorf("all of the constraints must be satisfied: "))
|
||||
for i, v := range s {
|
||||
if err := c(i, v).Satisfied(); err != nil {
|
||||
retErr.appendChild(err)
|
||||
}
|
||||
}
|
||||
if len(retErr.children) == 0 {
|
||||
return nil
|
||||
}
|
||||
return retErr
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// And groups multiple constraints in an "and" relation and fails according to the given strategy.
|
||||
func And(errStrat ErrStrategy, constraints ...*Constraint) *Constraint {
|
||||
return &Constraint{
|
||||
Satisfied: func() *TreeError {
|
||||
retErr := NewErrorTree(fmt.Errorf("all of the constraints must be satisfied: "))
|
||||
for _, constraint := range constraints {
|
||||
if err := constraint.Satisfied(); err != nil {
|
||||
if errStrat == FailFast {
|
||||
return err
|
||||
}
|
||||
retErr.appendChild(err)
|
||||
}
|
||||
}
|
||||
if len(retErr.children) == 0 {
|
||||
return nil
|
||||
}
|
||||
return retErr
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Or groups multiple constraints in an "or" relation.
|
||||
func Or(constraints ...*Constraint) *Constraint {
|
||||
return &Constraint{
|
||||
Satisfied: func() *TreeError {
|
||||
retErr := NewErrorTree(fmt.Errorf("at least one of the constraints must be satisfied: "))
|
||||
for _, constraint := range constraints {
|
||||
err := constraint.Satisfied()
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
retErr.appendChild(err)
|
||||
}
|
||||
if len(retErr.children) == 0 {
|
||||
return nil
|
||||
}
|
||||
return retErr
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// IfNotNil evaluates a constraint if and only if s is not nil.
|
||||
func IfNotNil[T comparable](s *T, c func() *Constraint) *Constraint {
|
||||
return &Constraint{
|
||||
Satisfied: func() *TreeError {
|
||||
if s == nil {
|
||||
return nil
|
||||
}
|
||||
return c().Satisfied()
|
||||
},
|
||||
}
|
||||
}
|
||||
|
290
internal/validation/constraints_test.go
Normal file
290
internal/validation/constraints_test.go
Normal file
@ -0,0 +1,290 @@
|
||||
/*
|
||||
Copyright (c) Edgeless Systems GmbH
|
||||
|
||||
SPDX-License-Identifier: AGPL-3.0-only
|
||||
*/
|
||||
|
||||
package validation
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestIPAddress(t *testing.T) {
|
||||
testCases := map[string]struct {
|
||||
ip string
|
||||
wantErr bool
|
||||
}{
|
||||
"valid ipv4": {
|
||||
ip: "127.0.0.1",
|
||||
},
|
||||
"valid ipv6": {
|
||||
ip: "2001:db8::68",
|
||||
},
|
||||
"invalid": {
|
||||
ip: "invalid",
|
||||
wantErr: true,
|
||||
},
|
||||
"empty": {
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for name, tc := range testCases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
err := IPAddress(tc.ip).Satisfied()
|
||||
if tc.wantErr {
|
||||
require.Error(t, err)
|
||||
} else {
|
||||
require.Nil(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCIDR(t *testing.T) {
|
||||
testCases := map[string]struct {
|
||||
cidr string
|
||||
wantErr bool
|
||||
}{
|
||||
"valid ipv4": {
|
||||
cidr: "192.0.2.0/24",
|
||||
},
|
||||
"valid ipv6": {
|
||||
cidr: "2001:db8::/32",
|
||||
},
|
||||
"invalid": {
|
||||
cidr: "invalid",
|
||||
wantErr: true,
|
||||
},
|
||||
"empty": {
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for name, tc := range testCases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
err := CIDR(tc.cidr).Satisfied()
|
||||
if tc.wantErr {
|
||||
require.Error(t, err)
|
||||
} else {
|
||||
require.Nil(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNSName(t *testing.T) {
|
||||
testCases := map[string]struct {
|
||||
dnsName string
|
||||
wantErr bool
|
||||
}{
|
||||
"valid": {
|
||||
dnsName: "example.com",
|
||||
},
|
||||
"invalid": {
|
||||
dnsName: "invalid",
|
||||
wantErr: true,
|
||||
},
|
||||
"empty": {
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for name, tc := range testCases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
err := DNSName(tc.dnsName).Satisfied()
|
||||
if tc.wantErr {
|
||||
require.Error(t, err)
|
||||
} else {
|
||||
require.Nil(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestEmptySlice(t *testing.T) {
|
||||
testCases := map[string]struct {
|
||||
s []any
|
||||
wantErr bool
|
||||
}{
|
||||
"valid": {
|
||||
s: []any{},
|
||||
},
|
||||
"nil": {
|
||||
s: nil,
|
||||
},
|
||||
"invalid": {
|
||||
s: []any{1},
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for name, tc := range testCases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
err := EmptySlice(tc.s).Satisfied()
|
||||
if tc.wantErr {
|
||||
require.Error(t, err)
|
||||
} else {
|
||||
require.Nil(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNotEmptySlice(t *testing.T) {
|
||||
testCases := map[string]struct {
|
||||
s []any
|
||||
wantErr bool
|
||||
}{
|
||||
"valid": {
|
||||
s: []any{1},
|
||||
},
|
||||
"invalid": {
|
||||
s: []any{},
|
||||
wantErr: true,
|
||||
},
|
||||
"nil": {
|
||||
s: nil,
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for name, tc := range testCases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
err := NotEmptySlice(tc.s).Satisfied()
|
||||
if tc.wantErr {
|
||||
require.Error(t, err)
|
||||
} else {
|
||||
require.Nil(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAll(t *testing.T) {
|
||||
c := func(i int, s string) *Constraint {
|
||||
return Equal(s, "abc")
|
||||
}
|
||||
testCases := map[string]struct {
|
||||
s []string
|
||||
wantErr bool
|
||||
}{
|
||||
"valid": {
|
||||
s: []string{"abc", "abc", "abc"},
|
||||
},
|
||||
"nil": {
|
||||
s: nil,
|
||||
},
|
||||
"empty": {
|
||||
s: []string{},
|
||||
},
|
||||
"all are invalid": {
|
||||
s: []string{"def", "lol"},
|
||||
wantErr: true,
|
||||
},
|
||||
"one is invalid": {
|
||||
s: []string{"abc", "def"},
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for name, tc := range testCases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
err := All(tc.s, c).Satisfied()
|
||||
if tc.wantErr {
|
||||
require.Error(t, err)
|
||||
} else {
|
||||
require.Nil(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNotEqual(t *testing.T) {
|
||||
testCases := map[string]struct {
|
||||
a any
|
||||
b any
|
||||
wantErr bool
|
||||
}{
|
||||
"valid": {
|
||||
a: "abc",
|
||||
b: "def",
|
||||
},
|
||||
"invalid": {
|
||||
a: "abc",
|
||||
b: "abc",
|
||||
wantErr: true,
|
||||
},
|
||||
"empty": {
|
||||
wantErr: true,
|
||||
},
|
||||
"one empty": {
|
||||
a: "abc",
|
||||
b: "",
|
||||
},
|
||||
"one nil": {
|
||||
a: "abc",
|
||||
b: nil,
|
||||
},
|
||||
"both nil": {
|
||||
a: nil,
|
||||
b: nil,
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for name, tc := range testCases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
err := NotEqual(tc.a, tc.b).Satisfied()
|
||||
if tc.wantErr {
|
||||
require.Error(t, err)
|
||||
} else {
|
||||
require.Nil(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIfNotNil(t *testing.T) {
|
||||
testCases := map[string]struct {
|
||||
a *int
|
||||
c func() *Constraint
|
||||
wantErr bool
|
||||
}{
|
||||
"valid": {
|
||||
a: new(int),
|
||||
c: func() *Constraint {
|
||||
return &Constraint{
|
||||
Satisfied: func() *TreeError {
|
||||
return nil
|
||||
},
|
||||
}
|
||||
},
|
||||
},
|
||||
"nil": {
|
||||
a: nil,
|
||||
c: func() *Constraint {
|
||||
return &Constraint{
|
||||
Satisfied: func() *TreeError {
|
||||
t.Fatal("should not be called")
|
||||
return nil
|
||||
},
|
||||
}
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for name, tc := range testCases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
err := IfNotNil(tc.a, tc.c).Satisfied()
|
||||
if tc.wantErr {
|
||||
require.Error(t, err)
|
||||
} else {
|
||||
require.Nil(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
@ -13,42 +13,86 @@ import (
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Error is returned when a document is not valid.
|
||||
type Error struct {
|
||||
Path string
|
||||
Err error
|
||||
// TreeError is returned when a document is not valid.
|
||||
// It contains the path to the field that failed validation, the error
|
||||
// that occurred, as well as a list of child errors, as one constraint
|
||||
// can embed multiple other constraints, e.g. in an OR.
|
||||
type TreeError struct {
|
||||
path string
|
||||
err error
|
||||
children []*TreeError
|
||||
}
|
||||
|
||||
// NewErrorTree creates a new error tree from the given error.
|
||||
func NewErrorTree(err error) *TreeError {
|
||||
return &TreeError{
|
||||
err: err,
|
||||
children: []*TreeError{},
|
||||
}
|
||||
}
|
||||
|
||||
/*
|
||||
newError creates a new validation Error.
|
||||
newTraceError creates a new validation error, traced to a field.
|
||||
|
||||
To find the path to the exported field that failed validation, it traverses "doc"
|
||||
recursively until it finds a field in "doc" that matches the reference to "field".
|
||||
*/
|
||||
func newError(doc, field referenceableValue, errMsg error) *Error {
|
||||
func newTraceError(doc, field referenceableValue, errMsg error) *TreeError {
|
||||
// traverse the top level struct (i.e. the "haystack") until addr (i.e. the "needle") is found
|
||||
path, err := traverse(doc, field, newPathBuilder(doc._type.Name()))
|
||||
if err != nil {
|
||||
return &Error{
|
||||
Path: "unknown",
|
||||
Err: fmt.Errorf("cannot find path to field: %w. original error: %w", err, errMsg),
|
||||
return &TreeError{
|
||||
path: "unknown",
|
||||
err: fmt.Errorf("cannot find path to field: %w. original error: %w", err, errMsg),
|
||||
}
|
||||
}
|
||||
|
||||
return &Error{
|
||||
Path: path,
|
||||
Err: errMsg,
|
||||
return &TreeError{
|
||||
path: path,
|
||||
err: errMsg,
|
||||
children: []*TreeError{},
|
||||
}
|
||||
}
|
||||
|
||||
// Error implements the error interface.
|
||||
func (e *Error) Error() string {
|
||||
return fmt.Sprintf("validating %s: %s", e.Path, e.Err)
|
||||
func (e *TreeError) Error() string {
|
||||
return e.format(0)
|
||||
}
|
||||
|
||||
// Unwrap implements the error interface.
|
||||
func (e *Error) Unwrap() error {
|
||||
return e.Err
|
||||
func (e *TreeError) Unwrap() error {
|
||||
return e.err
|
||||
}
|
||||
|
||||
// format formats the error tree and all of its children.
|
||||
func (e *TreeError) format(indent int) string {
|
||||
var sb strings.Builder
|
||||
if e.path != "" {
|
||||
sb.WriteString(fmt.Sprintf(
|
||||
"%svalidating %s: %s",
|
||||
strings.Repeat(" ", indent),
|
||||
e.path,
|
||||
e.err,
|
||||
))
|
||||
} else {
|
||||
sb.WriteString(fmt.Sprintf(
|
||||
"%s%s",
|
||||
strings.Repeat(" ", indent),
|
||||
e.err,
|
||||
))
|
||||
}
|
||||
for _, child := range e.children {
|
||||
sb.WriteString(fmt.Sprintf(
|
||||
"\n%s",
|
||||
child.format(indent+1),
|
||||
))
|
||||
}
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
// appendChild adds the given child error to the tree.
|
||||
func (e *TreeError) appendChild(child *TreeError) {
|
||||
e.children = append(e.children, child)
|
||||
}
|
||||
|
||||
/*
|
||||
@ -238,9 +282,13 @@ func newPathBuilder(topLevelDoc string) pathBuilder {
|
||||
func (p pathBuilder) appendStructField(field reflect.StructField) pathBuilder {
|
||||
switch {
|
||||
case field.Tag.Get("json") != "":
|
||||
p.buf = append(p.buf, fmt.Sprintf(".%s", field.Tag.Get("json")))
|
||||
// cut off omitempty or other options
|
||||
jsonTagName, _, _ := strings.Cut(field.Tag.Get("json"), ",")
|
||||
p.buf = append(p.buf, fmt.Sprintf(".%s", jsonTagName))
|
||||
case field.Tag.Get("yaml") != "":
|
||||
p.buf = append(p.buf, fmt.Sprintf(".%s", field.Tag.Get("yaml")))
|
||||
// cut off omitempty or other options
|
||||
yamlTagName, _, _ := strings.Cut(field.Tag.Get("yaml"), ",")
|
||||
p.buf = append(p.buf, fmt.Sprintf(".%s", yamlTagName))
|
||||
default:
|
||||
p.buf = append(p.buf, fmt.Sprintf(".%s", field.Name))
|
||||
}
|
||||
|
@ -15,6 +15,37 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestErrorFormatting(t *testing.T) {
|
||||
err := &TreeError{
|
||||
path: "path",
|
||||
err: fmt.Errorf("error"),
|
||||
children: []*TreeError{},
|
||||
}
|
||||
|
||||
assert.Equal(t, "validating path: error", err.Error())
|
||||
|
||||
err.children = append(err.children, &TreeError{
|
||||
path: "child",
|
||||
err: fmt.Errorf("child error"),
|
||||
children: []*TreeError{},
|
||||
})
|
||||
|
||||
assert.Equal(t, "validating path: error\n validating child: child error", err.Error())
|
||||
|
||||
err.children = append(err.children, &TreeError{
|
||||
path: "child2",
|
||||
err: fmt.Errorf("child2 error"),
|
||||
children: []*TreeError{
|
||||
{
|
||||
path: "child2child",
|
||||
err: fmt.Errorf("child2child error"),
|
||||
children: []*TreeError{},
|
||||
},
|
||||
},
|
||||
})
|
||||
assert.Equal(t, "validating path: error\n validating child: child error\n validating child2: child2 error\n validating child2child: child2child error", err.Error())
|
||||
}
|
||||
|
||||
// Tests for primitive / shallow fields
|
||||
|
||||
func TestNewValidationErrorSingleField(t *testing.T) {
|
||||
@ -24,7 +55,7 @@ func TestNewValidationErrorSingleField(t *testing.T) {
|
||||
}
|
||||
|
||||
doc, field := references(t, st, &st.OtherField, "")
|
||||
err := newError(doc, field, assert.AnError)
|
||||
err := newTraceError(doc, field, assert.AnError)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), fmt.Sprintf("validating errorTestDoc.otherField: %s", assert.AnError))
|
||||
}
|
||||
@ -37,7 +68,7 @@ func TestNewValidationErrorSingleFieldPtr(t *testing.T) {
|
||||
}
|
||||
|
||||
doc, field := references(t, st, &st.PointerField, "")
|
||||
err := newError(doc, field, assert.AnError)
|
||||
err := newTraceError(doc, field, assert.AnError)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), fmt.Sprintf("validating errorTestDoc.pointerField: %s", assert.AnError))
|
||||
}
|
||||
@ -51,7 +82,7 @@ func TestNewValidationErrorSingleFieldDoublePtr(t *testing.T) {
|
||||
}
|
||||
|
||||
doc, field := references(t, st, &st.DoublePointerField, "")
|
||||
err := newError(doc, field, assert.AnError)
|
||||
err := newTraceError(doc, field, assert.AnError)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), fmt.Sprintf("validating errorTestDoc.doublePointerField: %s", assert.AnError))
|
||||
}
|
||||
@ -66,7 +97,7 @@ func TestNewValidationErrorSingleFieldInexistent(t *testing.T) {
|
||||
inexistentField := 123
|
||||
|
||||
doc, field := references(t, st, &inexistentField, "")
|
||||
err := newError(doc, field, assert.AnError)
|
||||
err := newTraceError(doc, field, assert.AnError)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "cannot find path to field: cannot traverse anymore")
|
||||
}
|
||||
@ -84,7 +115,7 @@ func TestNewValidationErrorNestedField(t *testing.T) {
|
||||
}
|
||||
|
||||
doc, field := references(t, st, &st.NestedField.OtherField, "")
|
||||
err := newError(doc, field, assert.AnError)
|
||||
err := newTraceError(doc, field, assert.AnError)
|
||||
t.Log(err)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), fmt.Sprintf("validating errorTestDoc.nestedField.otherField: %s", assert.AnError))
|
||||
@ -102,7 +133,7 @@ func TestNewValidationErrorPointerInNestedField(t *testing.T) {
|
||||
}
|
||||
|
||||
doc, field := references(t, st, &st.NestedField.PointerField, "")
|
||||
err := newError(doc, field, assert.AnError)
|
||||
err := newTraceError(doc, field, assert.AnError)
|
||||
t.Log(err)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), fmt.Sprintf("validating errorTestDoc.nestedField.pointerField: %s", assert.AnError))
|
||||
@ -123,7 +154,7 @@ func TestNewValidationErrorNestedFieldPtr(t *testing.T) {
|
||||
}
|
||||
|
||||
doc, field := references(t, st, &st.NestedPointerField.OtherField, "")
|
||||
err := newError(doc, field, assert.AnError)
|
||||
err := newTraceError(doc, field, assert.AnError)
|
||||
t.Log(err)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), fmt.Sprintf("validating errorTestDoc.nestedPointerField.otherField: %s", assert.AnError))
|
||||
@ -144,7 +175,7 @@ func TestNewValidationErrorNestedNestedField(t *testing.T) {
|
||||
}
|
||||
|
||||
doc, field := references(t, st, &st.NestedField.NestedField.OtherField, "")
|
||||
err := newError(doc, field, assert.AnError)
|
||||
err := newTraceError(doc, field, assert.AnError)
|
||||
t.Log(err)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), fmt.Sprintf("validating errorTestDoc.nestedField.nestedField.otherField: %s", assert.AnError))
|
||||
@ -165,7 +196,7 @@ func TestNewValidationErrorNestedNestedFieldPtr(t *testing.T) {
|
||||
}
|
||||
|
||||
doc, field := references(t, st, &st.NestedField.NestedPointerField.OtherField, "")
|
||||
err := newError(doc, field, assert.AnError)
|
||||
err := newTraceError(doc, field, assert.AnError)
|
||||
t.Log(err)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), fmt.Sprintf("validating errorTestDoc.nestedField.nestedPointerField.otherField: %s", assert.AnError))
|
||||
@ -186,7 +217,7 @@ func TestNewValidationErrorNestedPtrNestedFieldPtr(t *testing.T) {
|
||||
}
|
||||
|
||||
doc, field := references(t, st, &st.NestedPointerField.NestedPointerField.OtherField, "")
|
||||
err := newError(doc, field, assert.AnError)
|
||||
err := newTraceError(doc, field, assert.AnError)
|
||||
t.Log(err)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), fmt.Sprintf("validating errorTestDoc.nestedPointerField.nestedPointerField.otherField: %s", assert.AnError))
|
||||
@ -200,7 +231,7 @@ func TestNewValidationErrorPrimitiveSlice(t *testing.T) {
|
||||
}
|
||||
|
||||
doc, field := references(t, st, &st.PrimitiveSlice[1], "")
|
||||
err := newError(doc, field, assert.AnError)
|
||||
err := newTraceError(doc, field, assert.AnError)
|
||||
t.Log(err)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), fmt.Sprintf("validating sliceErrorTestDoc.primitiveSlice[1]: %s", assert.AnError))
|
||||
@ -212,7 +243,7 @@ func TestNewValidationErrorPrimitiveArray(t *testing.T) {
|
||||
}
|
||||
|
||||
doc, field := references(t, st, &st.PrimitiveArray[1], "")
|
||||
err := newError(doc, field, assert.AnError)
|
||||
err := newTraceError(doc, field, assert.AnError)
|
||||
t.Log(err)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), fmt.Sprintf("validating sliceErrorTestDoc.primitiveArray[1]: %s", assert.AnError))
|
||||
@ -233,7 +264,7 @@ func TestNewValidationErrorStructSlice(t *testing.T) {
|
||||
}
|
||||
|
||||
doc, field := references(t, st, &st.StructSlice[1].OtherField, "")
|
||||
err := newError(doc, field, assert.AnError)
|
||||
err := newTraceError(doc, field, assert.AnError)
|
||||
t.Log(err)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), fmt.Sprintf("validating sliceErrorTestDoc.structSlice[1].otherField: %s", assert.AnError))
|
||||
@ -254,7 +285,7 @@ func TestNewValidationErrorStructArray(t *testing.T) {
|
||||
}
|
||||
|
||||
doc, field := references(t, st, &st.StructArray[1].OtherField, "")
|
||||
err := newError(doc, field, assert.AnError)
|
||||
err := newTraceError(doc, field, assert.AnError)
|
||||
t.Log(err)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), fmt.Sprintf("validating sliceErrorTestDoc.structArray[1].otherField: %s", assert.AnError))
|
||||
@ -275,7 +306,7 @@ func TestNewValidationErrorStructPointerSlice(t *testing.T) {
|
||||
}
|
||||
|
||||
doc, field := references(t, st, &st.StructPointerSlice[1].OtherField, "")
|
||||
err := newError(doc, field, assert.AnError)
|
||||
err := newTraceError(doc, field, assert.AnError)
|
||||
t.Log(err)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), fmt.Sprintf("validating sliceErrorTestDoc.structPointerSlice[1].otherField: %s", assert.AnError))
|
||||
@ -296,7 +327,7 @@ func TestNewValidationErrorStructPointerArray(t *testing.T) {
|
||||
}
|
||||
|
||||
doc, field := references(t, st, &st.StructPointerArray[1].OtherField, "")
|
||||
err := newError(doc, field, assert.AnError)
|
||||
err := newTraceError(doc, field, assert.AnError)
|
||||
t.Log(err)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), fmt.Sprintf("validating sliceErrorTestDoc.structPointerArray[1].otherField: %s", assert.AnError))
|
||||
@ -311,7 +342,7 @@ func TestNewValidationErrorPrimitiveSliceSlice(t *testing.T) {
|
||||
}
|
||||
|
||||
doc, field := references(t, st, &st.PrimitiveSliceSlice[1][1], "")
|
||||
err := newError(doc, field, assert.AnError)
|
||||
err := newTraceError(doc, field, assert.AnError)
|
||||
t.Log(err)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), fmt.Sprintf("validating sliceErrorTestDoc.primitiveSliceSlice[1][1]: %s", assert.AnError))
|
||||
@ -328,7 +359,7 @@ func TestNewValidationErrorPrimitiveMap(t *testing.T) {
|
||||
}
|
||||
|
||||
doc, field := references(t, st, &st.PrimitiveMap, "ghi")
|
||||
err := newError(doc, field, assert.AnError)
|
||||
err := newTraceError(doc, field, assert.AnError)
|
||||
t.Log(err)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), fmt.Sprintf("validating mapErrorTestDoc.primitiveMap[\"ghi\"]: %s", assert.AnError))
|
||||
@ -349,7 +380,7 @@ func TestNewValidationErrorStructPointerMap(t *testing.T) {
|
||||
}
|
||||
|
||||
doc, field := references(t, st, &st.StructPointerMap["ghi"].OtherField, "")
|
||||
err := newError(doc, field, assert.AnError)
|
||||
err := newTraceError(doc, field, assert.AnError)
|
||||
t.Log(err)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), fmt.Sprintf("validating mapErrorTestDoc.structPointerMap[\"ghi\"].otherField: %s", assert.AnError))
|
||||
@ -368,7 +399,7 @@ func TestNewValidationErrorNestedPrimitiveMap(t *testing.T) {
|
||||
}
|
||||
|
||||
doc, field := references(t, st, st.NestedPointerMap["jkl"], "mno")
|
||||
err := newError(doc, field, assert.AnError)
|
||||
err := newTraceError(doc, field, assert.AnError)
|
||||
t.Log(err)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), fmt.Sprintf("validating mapErrorTestDoc.nestedPointerMap[\"jkl\"][\"mno\"]: %s", assert.AnError))
|
||||
@ -383,7 +414,7 @@ func TestNewValidationErrorTopLevelIsNeedle(t *testing.T) {
|
||||
}
|
||||
|
||||
doc, field := references(t, st, st, "")
|
||||
err := newError(doc, field, assert.AnError)
|
||||
err := newTraceError(doc, field, assert.AnError)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), fmt.Sprintf("validating errorTestDoc: %s", assert.AnError))
|
||||
}
|
||||
@ -396,7 +427,7 @@ func TestNewValidationErrorUntaggedField(t *testing.T) {
|
||||
}
|
||||
|
||||
doc, field := references(t, st, &st.NoTagField, "")
|
||||
err := newError(doc, field, assert.AnError)
|
||||
err := newTraceError(doc, field, assert.AnError)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), fmt.Sprintf("validating errorTestDoc.NoTagField: %s", assert.AnError))
|
||||
}
|
||||
@ -410,7 +441,7 @@ func TestNewValidationErrorOnlyYamlTaggedField(t *testing.T) {
|
||||
}
|
||||
|
||||
doc, field := references(t, st, &st.OnlyYamlKey, "")
|
||||
err := newError(doc, field, assert.AnError)
|
||||
err := newTraceError(doc, field, assert.AnError)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), fmt.Sprintf("validating errorTestDoc.onlyYamlKey: %s", assert.AnError))
|
||||
}
|
||||
|
@ -11,7 +11,19 @@ It validates documents that specify a set of constraints on their content.
|
||||
*/
|
||||
package validation
|
||||
|
||||
import "errors"
|
||||
import (
|
||||
"errors"
|
||||
)
|
||||
|
||||
// ErrStrategy is the strategy to use when encountering an error during validation.
|
||||
type ErrStrategy int
|
||||
|
||||
const (
|
||||
// EvaluateAll continues evaluating all constraints even if one is not satisfied.
|
||||
EvaluateAll ErrStrategy = iota
|
||||
// FailFast stops validation on the first error.
|
||||
FailFast
|
||||
)
|
||||
|
||||
// NewValidator creates a new Validator.
|
||||
func NewValidator() *Validator {
|
||||
@ -24,21 +36,31 @@ type Validator struct{}
|
||||
// Validatable is implemented by documents that can be validated.
|
||||
// It returns a list of constraints that must be satisfied for the document to be valid.
|
||||
type Validatable interface {
|
||||
Constraints() []Constraint
|
||||
Constraints() []*Constraint
|
||||
}
|
||||
|
||||
// ValidateOptions are the options to use when validating a document.
|
||||
type ValidateOptions struct {
|
||||
// FailFast stops validation on the first error.
|
||||
FailFast bool
|
||||
// ErrStrategy is the strategy to use when encountering an error during validation.
|
||||
ErrStrategy ErrStrategy
|
||||
// OverrideConstraints overrides the constraints to use for validation.
|
||||
// If nil, the constraints returned by the document are used.
|
||||
OverrideConstraints func() []*Constraint
|
||||
}
|
||||
|
||||
// Validate validates a document using the given options.
|
||||
func (v *Validator) Validate(doc Validatable, opts ValidateOptions) error {
|
||||
var constraints func() []*Constraint
|
||||
if opts.OverrideConstraints != nil {
|
||||
constraints = opts.OverrideConstraints
|
||||
} else {
|
||||
constraints = doc.Constraints
|
||||
}
|
||||
|
||||
var retErr error
|
||||
for _, c := range doc.Constraints() {
|
||||
for _, c := range constraints() {
|
||||
if err := c.Satisfied(); err != nil {
|
||||
if opts.FailFast {
|
||||
if opts.ErrStrategy == FailFast {
|
||||
return err
|
||||
}
|
||||
retErr = errors.Join(retErr, err)
|
||||
|
@ -14,34 +14,39 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
var validDoc = func() *exampleDoc {
|
||||
return &exampleDoc{
|
||||
StrField: "abc",
|
||||
NumField: 42,
|
||||
MapField: &map[string]string{
|
||||
"empty": "",
|
||||
},
|
||||
NotEmptyField: "certainly not.",
|
||||
MatchRegexField: "abc",
|
||||
OneOfField: "one",
|
||||
OrLeftField: "left",
|
||||
OrRightField: "right",
|
||||
AndLeftField: "left",
|
||||
AndRightField: "right",
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidate(t *testing.T) {
|
||||
testCases := map[string]struct {
|
||||
doc Validatable
|
||||
doc func() *exampleDoc
|
||||
opts ValidateOptions
|
||||
wantErr bool
|
||||
errAssertion func(*assert.Assertions, error) bool
|
||||
}{
|
||||
"valid": {
|
||||
doc: &exampleDoc{
|
||||
StrField: "abc",
|
||||
NumField: 42,
|
||||
MapField: &map[string]string{
|
||||
"empty": "",
|
||||
},
|
||||
NotEmptyField: "certainly not.",
|
||||
MatchRegexField: "abc",
|
||||
},
|
||||
doc: validDoc,
|
||||
opts: ValidateOptions{},
|
||||
},
|
||||
"strField is not abc": {
|
||||
doc: &exampleDoc{
|
||||
StrField: "def",
|
||||
NumField: 42,
|
||||
MapField: &map[string]string{
|
||||
"empty": "",
|
||||
},
|
||||
NotEmptyField: "certainly not.",
|
||||
MatchRegexField: "abc",
|
||||
doc: func() *exampleDoc {
|
||||
doc := validDoc()
|
||||
doc.StrField = "def"
|
||||
return doc
|
||||
},
|
||||
wantErr: true,
|
||||
errAssertion: func(assert *assert.Assertions, err error) bool {
|
||||
@ -50,14 +55,10 @@ func TestValidate(t *testing.T) {
|
||||
opts: ValidateOptions{},
|
||||
},
|
||||
"numField is not 42": {
|
||||
doc: &exampleDoc{
|
||||
StrField: "abc",
|
||||
NumField: 43,
|
||||
MapField: &map[string]string{
|
||||
"empty": "",
|
||||
},
|
||||
NotEmptyField: "certainly not.",
|
||||
MatchRegexField: "abc",
|
||||
doc: func() *exampleDoc {
|
||||
doc := validDoc()
|
||||
doc.NumField = 43
|
||||
return doc
|
||||
},
|
||||
wantErr: true,
|
||||
errAssertion: func(assert *assert.Assertions, err error) bool {
|
||||
@ -65,14 +66,11 @@ func TestValidate(t *testing.T) {
|
||||
},
|
||||
},
|
||||
"multiple errors": {
|
||||
doc: &exampleDoc{
|
||||
StrField: "def",
|
||||
NumField: 43,
|
||||
MapField: &map[string]string{
|
||||
"empty": "",
|
||||
},
|
||||
NotEmptyField: "certainly not.",
|
||||
MatchRegexField: "abc",
|
||||
doc: func() *exampleDoc {
|
||||
doc := validDoc()
|
||||
doc.StrField = "def"
|
||||
doc.NumField = 43
|
||||
return doc
|
||||
},
|
||||
wantErr: true,
|
||||
errAssertion: func(assert *assert.Assertions, err error) bool {
|
||||
@ -82,75 +80,108 @@ func TestValidate(t *testing.T) {
|
||||
opts: ValidateOptions{},
|
||||
},
|
||||
"multiple errors, fail fast": {
|
||||
doc: &exampleDoc{
|
||||
StrField: "def",
|
||||
NumField: 43,
|
||||
MapField: &map[string]string{
|
||||
"empty": "",
|
||||
},
|
||||
NotEmptyField: "certainly not.",
|
||||
MatchRegexField: "abc",
|
||||
doc: func() *exampleDoc {
|
||||
doc := validDoc()
|
||||
doc.StrField = "def"
|
||||
doc.NumField = 43
|
||||
return doc
|
||||
},
|
||||
wantErr: true,
|
||||
errAssertion: func(assert *assert.Assertions, err error) bool {
|
||||
return assert.Contains(err.Error(), "validating exampleDoc.strField: def must be abc")
|
||||
},
|
||||
opts: ValidateOptions{
|
||||
FailFast: true,
|
||||
ErrStrategy: FailFast,
|
||||
},
|
||||
},
|
||||
"map field is not empty": {
|
||||
doc: &exampleDoc{
|
||||
StrField: "abc",
|
||||
NumField: 42,
|
||||
MapField: &map[string]string{
|
||||
doc: func() *exampleDoc {
|
||||
doc := validDoc()
|
||||
doc.MapField = &map[string]string{
|
||||
"empty": "haha!",
|
||||
},
|
||||
NotEmptyField: "certainly not.",
|
||||
MatchRegexField: "abc",
|
||||
}
|
||||
return doc
|
||||
},
|
||||
wantErr: true,
|
||||
errAssertion: func(assert *assert.Assertions, err error) bool {
|
||||
return assert.Contains(err.Error(), "validating exampleDoc.mapField[\"empty\"]: haha! must be empty")
|
||||
},
|
||||
opts: ValidateOptions{
|
||||
FailFast: true,
|
||||
ErrStrategy: FailFast,
|
||||
},
|
||||
},
|
||||
"empty field is not empty": {
|
||||
doc: &exampleDoc{
|
||||
StrField: "abc",
|
||||
NumField: 42,
|
||||
MapField: &map[string]string{
|
||||
"empty": "",
|
||||
},
|
||||
NotEmptyField: "",
|
||||
MatchRegexField: "abc",
|
||||
"not empty field is empty": {
|
||||
doc: func() *exampleDoc {
|
||||
doc := validDoc()
|
||||
doc.NotEmptyField = ""
|
||||
return doc
|
||||
},
|
||||
wantErr: true,
|
||||
errAssertion: func(assert *assert.Assertions, err error) bool {
|
||||
return assert.Contains(err.Error(), "validating exampleDoc.notEmptyField: must not be empty")
|
||||
return assert.Contains(err.Error(), "validating exampleDoc.notEmptyField: must not be empty")
|
||||
},
|
||||
opts: ValidateOptions{
|
||||
FailFast: true,
|
||||
ErrStrategy: FailFast,
|
||||
},
|
||||
},
|
||||
"regex doesnt match": {
|
||||
doc: &exampleDoc{
|
||||
StrField: "abc",
|
||||
NumField: 42,
|
||||
MapField: &map[string]string{
|
||||
"empty": "",
|
||||
},
|
||||
NotEmptyField: "certainly not!",
|
||||
MatchRegexField: "dontmatch",
|
||||
doc: func() *exampleDoc {
|
||||
doc := validDoc()
|
||||
doc.MatchRegexField = "dontmatch"
|
||||
return doc
|
||||
},
|
||||
wantErr: true,
|
||||
errAssertion: func(assert *assert.Assertions, err error) bool {
|
||||
return assert.Contains(err.Error(), "validating exampleDoc.matchRegexField: dontmatch must match the pattern ^a.c$")
|
||||
},
|
||||
opts: ValidateOptions{
|
||||
FailFast: true,
|
||||
ErrStrategy: FailFast,
|
||||
},
|
||||
},
|
||||
"field is not in 'oneof' values": {
|
||||
doc: func() *exampleDoc {
|
||||
doc := validDoc()
|
||||
doc.OneOfField = "not in oneof"
|
||||
return doc
|
||||
},
|
||||
wantErr: true,
|
||||
errAssertion: func(assert *assert.Assertions, err error) bool {
|
||||
return assert.Contains(err.Error(), "validating exampleDoc.oneOfField: not in oneof must be one of [one two three]")
|
||||
},
|
||||
opts: ValidateOptions{
|
||||
ErrStrategy: FailFast,
|
||||
},
|
||||
},
|
||||
"'or' violated": {
|
||||
doc: func() *exampleDoc {
|
||||
doc := validDoc()
|
||||
doc.OrLeftField = "not left"
|
||||
doc.OrRightField = "not right"
|
||||
return doc
|
||||
},
|
||||
wantErr: true,
|
||||
errAssertion: func(assert *assert.Assertions, err error) bool {
|
||||
return assert.Contains(err.Error(), "at least one of the constraints must be satisfied:") &&
|
||||
assert.Contains(err.Error(), "validating exampleDoc.orLeftField: not left must be equal to left") &&
|
||||
assert.Contains(err.Error(), "validating exampleDoc.orRightField: not right must be equal to right")
|
||||
},
|
||||
opts: ValidateOptions{
|
||||
ErrStrategy: FailFast,
|
||||
},
|
||||
},
|
||||
"'and' violated": {
|
||||
doc: func() *exampleDoc {
|
||||
doc := validDoc()
|
||||
doc.AndRightField = "not right"
|
||||
return doc
|
||||
},
|
||||
wantErr: true,
|
||||
errAssertion: func(assert *assert.Assertions, err error) bool {
|
||||
return assert.Contains(err.Error(), "all of the constraints must be satisfied:") &&
|
||||
assert.Contains(err.Error(), "validating exampleDoc.andRightField: not right must be equal to right")
|
||||
},
|
||||
opts: ValidateOptions{
|
||||
ErrStrategy: FailFast,
|
||||
},
|
||||
},
|
||||
}
|
||||
@ -160,7 +191,7 @@ func TestValidate(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
require := require.New(t)
|
||||
|
||||
err := NewValidator().Validate(tc.doc, tc.opts)
|
||||
err := NewValidator().Validate(tc.doc(), tc.opts)
|
||||
if tc.wantErr {
|
||||
require.Error(err)
|
||||
if !tc.errAssertion(assert, err) {
|
||||
@ -179,13 +210,18 @@ type exampleDoc struct {
|
||||
MapField *map[string]string `json:"mapField"`
|
||||
NotEmptyField string `json:"notEmptyField"`
|
||||
MatchRegexField string `json:"matchRegexField"`
|
||||
OneOfField string `json:"oneOfField"`
|
||||
OrLeftField string `json:"orLeftField"`
|
||||
OrRightField string `json:"orRightField"`
|
||||
AndLeftField string `json:"andLeftField"`
|
||||
AndRightField string `json:"andRightField"`
|
||||
}
|
||||
|
||||
// Constraints implements the Validatable interface.
|
||||
func (d *exampleDoc) Constraints() []Constraint {
|
||||
func (d *exampleDoc) Constraints() []*Constraint {
|
||||
mapField := *(d.MapField)
|
||||
|
||||
return []Constraint{
|
||||
return []*Constraint{
|
||||
d.strFieldNeedsToBeAbc().
|
||||
WithFieldTrace(d, &d.StrField),
|
||||
Equal(d.NumField, 42).
|
||||
@ -196,17 +232,95 @@ func (d *exampleDoc) Constraints() []Constraint {
|
||||
WithFieldTrace(d, &d.NotEmptyField),
|
||||
MatchRegex(d.MatchRegexField, "^a.c$").
|
||||
WithFieldTrace(d, &d.MatchRegexField),
|
||||
OneOf(d.OneOfField, []string{"one", "two", "three"}).
|
||||
WithFieldTrace(d, &d.OneOfField),
|
||||
Or(
|
||||
Equal(d.OrLeftField, "left").
|
||||
WithFieldTrace(d, &d.OrLeftField),
|
||||
Equal(d.OrRightField, "right").
|
||||
WithFieldTrace(d, &d.OrRightField),
|
||||
),
|
||||
And(
|
||||
EvaluateAll,
|
||||
Equal(d.AndLeftField, "left").
|
||||
WithFieldTrace(d, &d.AndLeftField),
|
||||
Equal(d.AndRightField, "right").
|
||||
WithFieldTrace(d, &d.AndRightField),
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
// StrFieldNeedsToBeAbc is an example for a custom constraint.
|
||||
func (d *exampleDoc) strFieldNeedsToBeAbc() *Constraint {
|
||||
return &Constraint{
|
||||
Satisfied: func() error {
|
||||
Satisfied: func() *TreeError {
|
||||
if d.StrField != "abc" {
|
||||
return fmt.Errorf("%s must be abc", d.StrField)
|
||||
return NewErrorTree(
|
||||
fmt.Errorf("%s must be abc", d.StrField),
|
||||
)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func TestOverrideConstraints(t *testing.T) {
|
||||
overrideConstraints := func(t *testing.T, wantCalled bool) func() []*Constraint {
|
||||
return func() []*Constraint {
|
||||
if !wantCalled {
|
||||
t.Fatal("overrideConstraints should not be called")
|
||||
}
|
||||
return []*Constraint{}
|
||||
}
|
||||
}
|
||||
|
||||
testCases := map[string]struct {
|
||||
doc exampleDocToOverride
|
||||
overrideFunc func() []*Constraint
|
||||
wantOverrideCalled bool
|
||||
wantErr bool
|
||||
}{
|
||||
"override constraints": {
|
||||
doc: exampleDocToOverride{},
|
||||
overrideFunc: overrideConstraints(t, true),
|
||||
wantOverrideCalled: true,
|
||||
},
|
||||
"do not override constraints": {
|
||||
doc: exampleDocToOverride{
|
||||
calledDocConstraints: true,
|
||||
},
|
||||
overrideFunc: nil,
|
||||
wantOverrideCalled: false,
|
||||
},
|
||||
}
|
||||
|
||||
for name, tc := range testCases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
require := require.New(t)
|
||||
|
||||
validator := NewValidator()
|
||||
err := validator.Validate(&tc.doc, ValidateOptions{
|
||||
OverrideConstraints: tc.overrideFunc,
|
||||
})
|
||||
|
||||
if tc.wantErr {
|
||||
require.Error(err)
|
||||
} else {
|
||||
require.NoError(err)
|
||||
if tc.wantOverrideCalled {
|
||||
assert.Equal(tc.doc.calledDocConstraints, false)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
type exampleDocToOverride struct {
|
||||
calledDocConstraints bool
|
||||
}
|
||||
|
||||
func (d *exampleDocToOverride) Constraints() []*Constraint {
|
||||
d.calledDocConstraints = true
|
||||
return []*Constraint{}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user