mirror of
https://github.com/edgelesssys/constellation.git
synced 2025-05-08 01:05:16 -04:00
cli: state file validation (#2523)
* re-use `ReadFromFile` in `CreateOrRead` Signed-off-by: Moritz Sanft <58110325+msanft@users.noreply.github.com> * [wip]: add constraints Signed-off-by: Moritz Sanft <58110325+msanft@users.noreply.github.com> * [wip] error formatting Signed-off-by: Moritz Sanft <58110325+msanft@users.noreply.github.com> * wip Signed-off-by: Moritz Sanft <58110325+msanft@users.noreply.github.com> * formatted error messages Signed-off-by: Moritz Sanft <58110325+msanft@users.noreply.github.com> * state file validation Signed-off-by: Moritz Sanft <58110325+msanft@users.noreply.github.com> * linter fixes Signed-off-by: Moritz Sanft <58110325+msanft@users.noreply.github.com> * allow overriding the constraints Signed-off-by: Moritz Sanft <58110325+msanft@users.noreply.github.com> * dont validate on read Signed-off-by: Moritz Sanft <58110325+msanft@users.noreply.github.com> * add pre-create constraints Signed-off-by: Moritz Sanft <58110325+msanft@users.noreply.github.com> * [wip] Signed-off-by: Moritz Sanft <58110325+msanft@users.noreply.github.com> * finish pre-init validation test Signed-off-by: Moritz Sanft <58110325+msanft@users.noreply.github.com> * finish post-init validation Signed-off-by: Moritz Sanft <58110325+msanft@users.noreply.github.com> * use state file validation in CLI Signed-off-by: Moritz Sanft <58110325+msanft@users.noreply.github.com> * fix apply tests Signed-off-by: Moritz Sanft <58110325+msanft@users.noreply.github.com> * Update internal/validation/errors.go Co-authored-by: Daniel Weiße <66256922+daniel-weisse@users.noreply.github.com> * use transformator for tests * tidy * use empty check directly Signed-off-by: Moritz Sanft <58110325+msanft@users.noreply.github.com> * Update cli/internal/state/state.go Co-authored-by: Daniel Weiße <66256922+daniel-weisse@users.noreply.github.com> * Update cli/internal/state/state.go Co-authored-by: Daniel Weiße <66256922+daniel-weisse@users.noreply.github.com> * Update cli/internal/state/state.go Co-authored-by: Daniel Weiße <66256922+daniel-weisse@users.noreply.github.com> * Update cli/internal/state/state.go Co-authored-by: Daniel Weiße <66256922+daniel-weisse@users.noreply.github.com> * conditional validation per CSP Signed-off-by: Moritz Sanft <58110325+msanft@users.noreply.github.com> * tidy Signed-off-by: Moritz Sanft <58110325+msanft@users.noreply.github.com> * fix rebase Signed-off-by: Moritz Sanft <58110325+msanft@users.noreply.github.com> * add default case Signed-off-by: Moritz Sanft <58110325+msanft@users.noreply.github.com> * validate state-file as last input Signed-off-by: Moritz Sanft <58110325+msanft@users.noreply.github.com> --------- Signed-off-by: Moritz Sanft <58110325+msanft@users.noreply.github.com> Co-authored-by: Daniel Weiße <66256922+daniel-weisse@users.noreply.github.com>
This commit is contained in:
parent
eaec73cca4
commit
744a605602
21 changed files with 1779 additions and 247 deletions
|
@ -419,12 +419,6 @@ func (a *applyCmd) validateInputs(cmd *cobra.Command, configFetcher attestationc
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
a.log.Debugf("Reading state file from %s", a.flags.pathPrefixer.PrefixPrintablePath(constants.StateFilename))
|
|
||||||
stateFile, err := state.ReadFromFile(a.fileHandler, constants.StateFilename)
|
|
||||||
if err != nil {
|
|
||||||
return nil, nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check license
|
// Check license
|
||||||
a.log.Debugf("Running license check")
|
a.log.Debugf("Running license check")
|
||||||
checker := license.NewChecker(a.quotaChecker, a.fileHandler)
|
checker := license.NewChecker(a.quotaChecker, a.fileHandler)
|
||||||
|
@ -517,6 +511,27 @@ func (a *applyCmd) validateInputs(cmd *cobra.Command, configFetcher attestationc
|
||||||
cmd.PrintErrln("WARNING: Attestation temporarily relies on AWS nitroTPM. See https://docs.edgeless.systems/constellation/workflows/config#choosing-a-vm-type for more information.")
|
cmd.PrintErrln("WARNING: Attestation temporarily relies on AWS nitroTPM. See https://docs.edgeless.systems/constellation/workflows/config#choosing-a-vm-type for more information.")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Read and validate state file
|
||||||
|
// This needs to be done as a last step, as we need to parse all other inputs to
|
||||||
|
// know which phases are skipped.
|
||||||
|
a.log.Debugf("Reading state file from %s", a.flags.pathPrefixer.PrefixPrintablePath(constants.StateFilename))
|
||||||
|
stateFile, err := state.ReadFromFile(a.fileHandler, constants.StateFilename)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
if a.flags.skipPhases.contains(skipInitPhase) {
|
||||||
|
// If the skipInit flag is set, we are in a state where the cluster
|
||||||
|
// has already been initialized and check against the respective constraints.
|
||||||
|
if err := stateFile.Validate(state.PostInit, conf.GetProvider()); err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// The cluster has not been initialized yet, so we check against the pre-init constraints.
|
||||||
|
if err := stateFile.Validate(state.PreInit, conf.GetProvider()); err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return conf, stateFile, nil
|
return conf, stateFile, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -14,6 +14,7 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/edgelesssys/constellation/v2/cli/internal/helm"
|
"github.com/edgelesssys/constellation/v2/cli/internal/helm"
|
||||||
|
"github.com/edgelesssys/constellation/v2/cli/internal/state"
|
||||||
"github.com/edgelesssys/constellation/v2/internal/file"
|
"github.com/edgelesssys/constellation/v2/internal/file"
|
||||||
"github.com/edgelesssys/constellation/v2/internal/logger"
|
"github.com/edgelesssys/constellation/v2/internal/logger"
|
||||||
"github.com/spf13/afero"
|
"github.com/spf13/afero"
|
||||||
|
@ -22,6 +23,54 @@ import (
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// defaultStateFile returns a valid default state for testing.
|
||||||
|
func defaultStateFile() *state.State {
|
||||||
|
return &state.State{
|
||||||
|
Version: "v1",
|
||||||
|
Infrastructure: state.Infrastructure{
|
||||||
|
UID: "123",
|
||||||
|
Name: "test-cluster",
|
||||||
|
ClusterEndpoint: "192.0.2.1",
|
||||||
|
InClusterEndpoint: "192.0.2.1",
|
||||||
|
InitSecret: []byte{0x41},
|
||||||
|
APIServerCertSANs: []string{
|
||||||
|
"127.0.0.1",
|
||||||
|
"www.example.com",
|
||||||
|
},
|
||||||
|
IPCidrNode: "0.0.0.0/24",
|
||||||
|
Azure: &state.Azure{
|
||||||
|
ResourceGroup: "test-rg",
|
||||||
|
SubscriptionID: "test-sub",
|
||||||
|
NetworkSecurityGroupName: "test-nsg",
|
||||||
|
LoadBalancerName: "test-lb",
|
||||||
|
UserAssignedIdentity: "test-uami",
|
||||||
|
AttestationURL: "test-maaUrl",
|
||||||
|
},
|
||||||
|
GCP: &state.GCP{
|
||||||
|
ProjectID: "test-project",
|
||||||
|
IPCidrPod: "0.0.0.0/24",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
ClusterValues: state.ClusterValues{
|
||||||
|
ClusterID: "deadbeef",
|
||||||
|
OwnerID: "deadbeef",
|
||||||
|
MeasurementSalt: []byte{0x41},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func defaultAzureStateFile() *state.State {
|
||||||
|
s := defaultStateFile()
|
||||||
|
s.Infrastructure.GCP = nil
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
func defaultGCPStateFile() *state.State {
|
||||||
|
s := defaultStateFile()
|
||||||
|
s.Infrastructure.Azure = nil
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
func TestParseApplyFlags(t *testing.T) {
|
func TestParseApplyFlags(t *testing.T) {
|
||||||
require := require.New(t)
|
require := require.New(t)
|
||||||
defaultFlags := func() *pflag.FlagSet {
|
defaultFlags := func() *pflag.FlagSet {
|
||||||
|
|
|
@ -202,6 +202,9 @@ func (c *createCmd) create(cmd *cobra.Command, applier cloudApplier, fileHandler
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("reading state file: %w", err)
|
return fmt.Errorf("reading state file: %w", err)
|
||||||
}
|
}
|
||||||
|
if err := stateFile.Validate(state.PreCreate, conf.GetProvider()); err != nil {
|
||||||
|
return fmt.Errorf("validating state file: %w", err)
|
||||||
|
}
|
||||||
stateFile = stateFile.SetInfrastructure(infraState)
|
stateFile = stateFile.SetInfrastructure(infraState)
|
||||||
if err := stateFile.WriteToFile(fileHandler, constants.StateFilename); err != nil {
|
if err := stateFile.WriteToFile(fileHandler, constants.StateFilename); err != nil {
|
||||||
return fmt.Errorf("writing state file: %w", err)
|
return fmt.Errorf("writing state file: %w", err)
|
||||||
|
|
|
@ -22,12 +22,21 @@ import (
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// preCreateStateFile returns a state file satisfying the pre-create state file
|
||||||
|
// constraints.
|
||||||
|
func preCreateStateFile() *state.State {
|
||||||
|
s := defaultAzureStateFile()
|
||||||
|
s.ClusterValues = state.ClusterValues{}
|
||||||
|
s.Infrastructure = state.Infrastructure{}
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
func TestCreate(t *testing.T) {
|
func TestCreate(t *testing.T) {
|
||||||
fsWithDefaultConfigAndState := func(require *require.Assertions, provider cloudprovider.Provider) afero.Fs {
|
fsWithDefaultConfigAndState := func(require *require.Assertions, provider cloudprovider.Provider) afero.Fs {
|
||||||
fs := afero.NewMemMapFs()
|
fs := afero.NewMemMapFs()
|
||||||
file := file.NewHandler(fs)
|
file := file.NewHandler(fs)
|
||||||
require.NoError(file.WriteYAML(constants.ConfigFilename, defaultConfigWithExpectedMeasurements(t, config.Default(), provider)))
|
require.NoError(file.WriteYAML(constants.ConfigFilename, defaultConfigWithExpectedMeasurements(t, config.Default(), provider)))
|
||||||
stateFile := state.New()
|
stateFile := preCreateStateFile()
|
||||||
switch provider {
|
switch provider {
|
||||||
case cloudprovider.GCP:
|
case cloudprovider.GCP:
|
||||||
stateFile.SetInfrastructure(state.Infrastructure{GCP: &state.GCP{}})
|
stateFile.SetInfrastructure(state.Infrastructure{GCP: &state.GCP{}})
|
||||||
|
|
|
@ -59,6 +59,14 @@ func TestInitArgumentValidation(t *testing.T) {
|
||||||
assert.Error(cmd.ValidateArgs([]string{"sth", "sth"}))
|
assert.Error(cmd.ValidateArgs([]string{"sth", "sth"}))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// preInitStateFile returns a state file satisfying the pre-init state file
|
||||||
|
// constraints.
|
||||||
|
func preInitStateFile() *state.State {
|
||||||
|
s := defaultAzureStateFile()
|
||||||
|
s.ClusterValues = state.ClusterValues{}
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
func TestInitialize(t *testing.T) {
|
func TestInitialize(t *testing.T) {
|
||||||
respKubeconfig := k8sclientapi.Config{
|
respKubeconfig := k8sclientapi.Config{
|
||||||
Clusters: map[string]*k8sclientapi.Cluster{
|
Clusters: map[string]*k8sclientapi.Cluster{
|
||||||
|
@ -101,24 +109,24 @@ func TestInitialize(t *testing.T) {
|
||||||
}{
|
}{
|
||||||
"initialize some gcp instances": {
|
"initialize some gcp instances": {
|
||||||
provider: cloudprovider.GCP,
|
provider: cloudprovider.GCP,
|
||||||
stateFile: &state.State{Version: state.Version1, Infrastructure: state.Infrastructure{ClusterEndpoint: "192.0.2.1"}},
|
stateFile: preInitStateFile(),
|
||||||
configMutator: func(c *config.Config) { c.Provider.GCP.ServiceAccountKeyPath = serviceAccPath },
|
configMutator: func(c *config.Config) { c.Provider.GCP.ServiceAccountKeyPath = serviceAccPath },
|
||||||
serviceAccKey: gcpServiceAccKey,
|
serviceAccKey: gcpServiceAccKey,
|
||||||
initServerAPI: &stubInitServer{res: []*initproto.InitResponse{{Kind: &initproto.InitResponse_InitSuccess{InitSuccess: testInitResp}}}},
|
initServerAPI: &stubInitServer{res: []*initproto.InitResponse{{Kind: &initproto.InitResponse_InitSuccess{InitSuccess: testInitResp}}}},
|
||||||
},
|
},
|
||||||
"initialize some azure instances": {
|
"initialize some azure instances": {
|
||||||
provider: cloudprovider.Azure,
|
provider: cloudprovider.Azure,
|
||||||
stateFile: &state.State{Version: state.Version1, Infrastructure: state.Infrastructure{ClusterEndpoint: "192.0.2.1"}},
|
stateFile: preInitStateFile(),
|
||||||
initServerAPI: &stubInitServer{res: []*initproto.InitResponse{{Kind: &initproto.InitResponse_InitSuccess{InitSuccess: testInitResp}}}},
|
initServerAPI: &stubInitServer{res: []*initproto.InitResponse{{Kind: &initproto.InitResponse_InitSuccess{InitSuccess: testInitResp}}}},
|
||||||
},
|
},
|
||||||
"initialize some qemu instances": {
|
"initialize some qemu instances": {
|
||||||
provider: cloudprovider.QEMU,
|
provider: cloudprovider.QEMU,
|
||||||
stateFile: &state.State{Version: state.Version1, Infrastructure: state.Infrastructure{ClusterEndpoint: "192.0.2.1"}},
|
stateFile: preInitStateFile(),
|
||||||
initServerAPI: &stubInitServer{res: []*initproto.InitResponse{{Kind: &initproto.InitResponse_InitSuccess{InitSuccess: testInitResp}}}},
|
initServerAPI: &stubInitServer{res: []*initproto.InitResponse{{Kind: &initproto.InitResponse_InitSuccess{InitSuccess: testInitResp}}}},
|
||||||
},
|
},
|
||||||
"non retriable error": {
|
"non retriable error": {
|
||||||
provider: cloudprovider.QEMU,
|
provider: cloudprovider.QEMU,
|
||||||
stateFile: &state.State{Version: state.Version1, Infrastructure: state.Infrastructure{ClusterEndpoint: "192.0.2.1"}},
|
stateFile: preInitStateFile(),
|
||||||
initServerAPI: &stubInitServer{initErr: &nonRetriableError{err: assert.AnError}},
|
initServerAPI: &stubInitServer{initErr: &nonRetriableError{err: assert.AnError}},
|
||||||
retriable: false,
|
retriable: false,
|
||||||
masterSecretShouldExist: true,
|
masterSecretShouldExist: true,
|
||||||
|
@ -126,7 +134,7 @@ func TestInitialize(t *testing.T) {
|
||||||
},
|
},
|
||||||
"non retriable error with failed log collection": {
|
"non retriable error with failed log collection": {
|
||||||
provider: cloudprovider.QEMU,
|
provider: cloudprovider.QEMU,
|
||||||
stateFile: &state.State{Version: state.Version1, Infrastructure: state.Infrastructure{ClusterEndpoint: "192.0.2.1"}},
|
stateFile: preInitStateFile(),
|
||||||
initServerAPI: &stubInitServer{
|
initServerAPI: &stubInitServer{
|
||||||
res: []*initproto.InitResponse{
|
res: []*initproto.InitResponse{
|
||||||
{
|
{
|
||||||
|
@ -149,31 +157,24 @@ func TestInitialize(t *testing.T) {
|
||||||
masterSecretShouldExist: true,
|
masterSecretShouldExist: true,
|
||||||
wantErr: true,
|
wantErr: true,
|
||||||
},
|
},
|
||||||
/*
|
"invalid state file": {
|
||||||
Tests currently disabled since we don't actually have validation for the state file yet
|
provider: cloudprovider.GCP,
|
||||||
These tests cases only passed in the past because of unrelated errors in the test setup
|
stateFile: &state.State{Version: "invalid"},
|
||||||
TODO(AB#3492): Re-enable tests once state file validation is implemented
|
configMutator: func(c *config.Config) { c.Provider.GCP.ServiceAccountKeyPath = serviceAccPath },
|
||||||
|
serviceAccKey: gcpServiceAccKey,
|
||||||
"state file with only version": {
|
initServerAPI: &stubInitServer{},
|
||||||
provider: cloudprovider.GCP,
|
retriable: true,
|
||||||
stateFile: &state.State{Version: state.Version1},
|
wantErr: true,
|
||||||
configMutator: func(c *config.Config) { c.Provider.GCP.ServiceAccountKeyPath = serviceAccPath },
|
},
|
||||||
serviceAccKey: gcpServiceAccKey,
|
"empty state file": {
|
||||||
initServerAPI: &stubInitServer{},
|
provider: cloudprovider.GCP,
|
||||||
retriable: true,
|
stateFile: &state.State{},
|
||||||
wantErr: true,
|
configMutator: func(c *config.Config) { c.Provider.GCP.ServiceAccountKeyPath = serviceAccPath },
|
||||||
},
|
serviceAccKey: gcpServiceAccKey,
|
||||||
|
initServerAPI: &stubInitServer{},
|
||||||
"empty state file": {
|
retriable: true,
|
||||||
provider: cloudprovider.GCP,
|
wantErr: true,
|
||||||
stateFile: &state.State{},
|
},
|
||||||
configMutator: func(c *config.Config) { c.Provider.GCP.ServiceAccountKeyPath = serviceAccPath },
|
|
||||||
serviceAccKey: gcpServiceAccKey,
|
|
||||||
initServerAPI: &stubInitServer{},
|
|
||||||
retriable: true,
|
|
||||||
wantErr: true,
|
|
||||||
},
|
|
||||||
*/
|
|
||||||
"no state file": {
|
"no state file": {
|
||||||
provider: cloudprovider.GCP,
|
provider: cloudprovider.GCP,
|
||||||
configMutator: func(c *config.Config) { c.Provider.GCP.ServiceAccountKeyPath = serviceAccPath },
|
configMutator: func(c *config.Config) { c.Provider.GCP.ServiceAccountKeyPath = serviceAccPath },
|
||||||
|
@ -184,7 +185,7 @@ func TestInitialize(t *testing.T) {
|
||||||
"init call fails": {
|
"init call fails": {
|
||||||
provider: cloudprovider.GCP,
|
provider: cloudprovider.GCP,
|
||||||
configMutator: func(c *config.Config) { c.Provider.GCP.ServiceAccountKeyPath = serviceAccPath },
|
configMutator: func(c *config.Config) { c.Provider.GCP.ServiceAccountKeyPath = serviceAccPath },
|
||||||
stateFile: &state.State{Version: state.Version1, Infrastructure: state.Infrastructure{ClusterEndpoint: "192.0.2.1"}},
|
stateFile: preInitStateFile(),
|
||||||
serviceAccKey: gcpServiceAccKey,
|
serviceAccKey: gcpServiceAccKey,
|
||||||
initServerAPI: &stubInitServer{initErr: assert.AnError},
|
initServerAPI: &stubInitServer{initErr: assert.AnError},
|
||||||
retriable: false,
|
retriable: false,
|
||||||
|
@ -193,7 +194,7 @@ func TestInitialize(t *testing.T) {
|
||||||
},
|
},
|
||||||
"k8s version without v works": {
|
"k8s version without v works": {
|
||||||
provider: cloudprovider.Azure,
|
provider: cloudprovider.Azure,
|
||||||
stateFile: &state.State{Version: state.Version1, Infrastructure: state.Infrastructure{ClusterEndpoint: "192.0.2.1"}},
|
stateFile: preInitStateFile(),
|
||||||
initServerAPI: &stubInitServer{res: []*initproto.InitResponse{{Kind: &initproto.InitResponse_InitSuccess{InitSuccess: testInitResp}}}},
|
initServerAPI: &stubInitServer{res: []*initproto.InitResponse{{Kind: &initproto.InitResponse_InitSuccess{InitSuccess: testInitResp}}}},
|
||||||
configMutator: func(c *config.Config) {
|
configMutator: func(c *config.Config) {
|
||||||
res, err := versions.NewValidK8sVersion(strings.TrimPrefix(string(versions.Default), "v"), true)
|
res, err := versions.NewValidK8sVersion(strings.TrimPrefix(string(versions.Default), "v"), true)
|
||||||
|
@ -203,7 +204,7 @@ func TestInitialize(t *testing.T) {
|
||||||
},
|
},
|
||||||
"outdated k8s patch version doesn't work": {
|
"outdated k8s patch version doesn't work": {
|
||||||
provider: cloudprovider.Azure,
|
provider: cloudprovider.Azure,
|
||||||
stateFile: &state.State{Version: state.Version1, Infrastructure: state.Infrastructure{ClusterEndpoint: "192.0.2.1"}},
|
stateFile: preInitStateFile(),
|
||||||
initServerAPI: &stubInitServer{res: []*initproto.InitResponse{{Kind: &initproto.InitResponse_InitSuccess{InitSuccess: testInitResp}}}},
|
initServerAPI: &stubInitServer{res: []*initproto.InitResponse{{Kind: &initproto.InitResponse_InitSuccess{InitSuccess: testInitResp}}}},
|
||||||
configMutator: func(c *config.Config) {
|
configMutator: func(c *config.Config) {
|
||||||
v, err := semver.New(versions.SupportedK8sVersions()[0])
|
v, err := semver.New(versions.SupportedK8sVersions()[0])
|
||||||
|
|
|
@ -119,6 +119,10 @@ func (r *recoverCmd) recover(
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("reading state file: %w", err)
|
return fmt.Errorf("reading state file: %w", err)
|
||||||
}
|
}
|
||||||
|
if err := stateFile.Validate(state.PostInit, provider); err != nil {
|
||||||
|
return fmt.Errorf("validating state file: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
endpoint, err := r.parseEndpoint(stateFile)
|
endpoint, err := r.parseEndpoint(stateFile)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
|
|
@ -15,7 +15,6 @@ import (
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/edgelesssys/constellation/v2/cli/internal/state"
|
|
||||||
"github.com/edgelesssys/constellation/v2/disk-mapper/recoverproto"
|
"github.com/edgelesssys/constellation/v2/disk-mapper/recoverproto"
|
||||||
"github.com/edgelesssys/constellation/v2/internal/atls"
|
"github.com/edgelesssys/constellation/v2/internal/atls"
|
||||||
"github.com/edgelesssys/constellation/v2/internal/cloud/cloudprovider"
|
"github.com/edgelesssys/constellation/v2/internal/cloud/cloudprovider"
|
||||||
|
@ -159,7 +158,7 @@ func TestRecover(t *testing.T) {
|
||||||
))
|
))
|
||||||
require.NoError(fileHandler.WriteYAML(
|
require.NoError(fileHandler.WriteYAML(
|
||||||
constants.StateFilename,
|
constants.StateFilename,
|
||||||
state.New(),
|
defaultGCPStateFile(),
|
||||||
file.OptNone,
|
file.OptNone,
|
||||||
))
|
))
|
||||||
|
|
||||||
|
|
|
@ -33,18 +33,10 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestUpgradeApply(t *testing.T) {
|
func TestUpgradeApply(t *testing.T) {
|
||||||
defaultState := state.New().
|
|
||||||
SetInfrastructure(state.Infrastructure{
|
|
||||||
APIServerCertSANs: []string{},
|
|
||||||
UID: "uid",
|
|
||||||
Name: "kubernetes-uid", // default test cfg uses "kubernetes" prefix
|
|
||||||
InitSecret: []byte{0x42},
|
|
||||||
}).
|
|
||||||
SetClusterValues(state.ClusterValues{MeasurementSalt: []byte{0x41}})
|
|
||||||
fsWithStateFileAndTfState := func() file.Handler {
|
fsWithStateFileAndTfState := func() file.Handler {
|
||||||
fh := file.NewHandler(afero.NewMemMapFs())
|
fh := file.NewHandler(afero.NewMemMapFs())
|
||||||
require.NoError(t, fh.MkdirAll(constants.TerraformWorkingDir))
|
require.NoError(t, fh.MkdirAll(constants.TerraformWorkingDir))
|
||||||
require.NoError(t, fh.WriteYAML(constants.StateFilename, defaultState))
|
require.NoError(t, fh.WriteYAML(constants.StateFilename, defaultAzureStateFile()))
|
||||||
return fh
|
return fh
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -63,20 +55,20 @@ func TestUpgradeApply(t *testing.T) {
|
||||||
kubeUpgrader: &stubKubernetesUpgrader{currentConfig: config.DefaultForAzureSEVSNP()},
|
kubeUpgrader: &stubKubernetesUpgrader{currentConfig: config.DefaultForAzureSEVSNP()},
|
||||||
helmUpgrader: stubApplier{},
|
helmUpgrader: stubApplier{},
|
||||||
terraformUpgrader: &stubTerraformUpgrader{},
|
terraformUpgrader: &stubTerraformUpgrader{},
|
||||||
flags: applyFlags{yes: true},
|
flags: applyFlags{yes: true, skipPhases: skipPhases{skipInitPhase: struct{}{}}},
|
||||||
fh: fsWithStateFileAndTfState,
|
fh: fsWithStateFileAndTfState,
|
||||||
fhAssertions: func(require *require.Assertions, assert *assert.Assertions, fh file.Handler) {
|
fhAssertions: func(require *require.Assertions, assert *assert.Assertions, fh file.Handler) {
|
||||||
gotState, err := state.ReadFromFile(fh, constants.StateFilename)
|
gotState, err := state.ReadFromFile(fh, constants.StateFilename)
|
||||||
require.NoError(err)
|
require.NoError(err)
|
||||||
assert.Equal("v1", gotState.Version)
|
assert.Equal("v1", gotState.Version)
|
||||||
assert.Equal(defaultState, gotState)
|
assert.Equal(defaultAzureStateFile(), gotState)
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
"id file and state file do not exist": {
|
"id file and state file do not exist": {
|
||||||
kubeUpgrader: &stubKubernetesUpgrader{currentConfig: config.DefaultForAzureSEVSNP()},
|
kubeUpgrader: &stubKubernetesUpgrader{currentConfig: config.DefaultForAzureSEVSNP()},
|
||||||
helmUpgrader: stubApplier{},
|
helmUpgrader: stubApplier{},
|
||||||
terraformUpgrader: &stubTerraformUpgrader{},
|
terraformUpgrader: &stubTerraformUpgrader{},
|
||||||
flags: applyFlags{yes: true},
|
flags: applyFlags{yes: true, skipPhases: skipPhases{skipInitPhase: struct{}{}}},
|
||||||
fh: func() file.Handler {
|
fh: func() file.Handler {
|
||||||
return file.NewHandler(afero.NewMemMapFs())
|
return file.NewHandler(afero.NewMemMapFs())
|
||||||
},
|
},
|
||||||
|
@ -90,7 +82,7 @@ func TestUpgradeApply(t *testing.T) {
|
||||||
helmUpgrader: stubApplier{},
|
helmUpgrader: stubApplier{},
|
||||||
terraformUpgrader: &stubTerraformUpgrader{},
|
terraformUpgrader: &stubTerraformUpgrader{},
|
||||||
wantErr: true,
|
wantErr: true,
|
||||||
flags: applyFlags{yes: true},
|
flags: applyFlags{yes: true, skipPhases: skipPhases{skipInitPhase: struct{}{}}},
|
||||||
fh: fsWithStateFileAndTfState,
|
fh: fsWithStateFileAndTfState,
|
||||||
},
|
},
|
||||||
"nodeVersion in progress error": {
|
"nodeVersion in progress error": {
|
||||||
|
@ -100,7 +92,7 @@ func TestUpgradeApply(t *testing.T) {
|
||||||
},
|
},
|
||||||
helmUpgrader: stubApplier{},
|
helmUpgrader: stubApplier{},
|
||||||
terraformUpgrader: &stubTerraformUpgrader{},
|
terraformUpgrader: &stubTerraformUpgrader{},
|
||||||
flags: applyFlags{yes: true},
|
flags: applyFlags{yes: true, skipPhases: skipPhases{skipInitPhase: struct{}{}}},
|
||||||
fh: fsWithStateFileAndTfState,
|
fh: fsWithStateFileAndTfState,
|
||||||
},
|
},
|
||||||
"helm other error": {
|
"helm other error": {
|
||||||
|
@ -110,7 +102,7 @@ func TestUpgradeApply(t *testing.T) {
|
||||||
helmUpgrader: stubApplier{err: assert.AnError},
|
helmUpgrader: stubApplier{err: assert.AnError},
|
||||||
terraformUpgrader: &stubTerraformUpgrader{},
|
terraformUpgrader: &stubTerraformUpgrader{},
|
||||||
wantErr: true,
|
wantErr: true,
|
||||||
flags: applyFlags{yes: true},
|
flags: applyFlags{yes: true, skipPhases: skipPhases{skipInitPhase: struct{}{}}},
|
||||||
fh: fsWithStateFileAndTfState,
|
fh: fsWithStateFileAndTfState,
|
||||||
},
|
},
|
||||||
"abort": {
|
"abort": {
|
||||||
|
@ -140,7 +132,7 @@ func TestUpgradeApply(t *testing.T) {
|
||||||
helmUpgrader: stubApplier{},
|
helmUpgrader: stubApplier{},
|
||||||
terraformUpgrader: &stubTerraformUpgrader{planTerraformErr: assert.AnError},
|
terraformUpgrader: &stubTerraformUpgrader{planTerraformErr: assert.AnError},
|
||||||
wantErr: true,
|
wantErr: true,
|
||||||
flags: applyFlags{yes: true},
|
flags: applyFlags{yes: true, skipPhases: skipPhases{skipInitPhase: struct{}{}}},
|
||||||
fh: fsWithStateFileAndTfState,
|
fh: fsWithStateFileAndTfState,
|
||||||
},
|
},
|
||||||
"apply terraform error": {
|
"apply terraform error": {
|
||||||
|
@ -153,7 +145,7 @@ func TestUpgradeApply(t *testing.T) {
|
||||||
terraformDiff: true,
|
terraformDiff: true,
|
||||||
},
|
},
|
||||||
wantErr: true,
|
wantErr: true,
|
||||||
flags: applyFlags{yes: true},
|
flags: applyFlags{yes: true, skipPhases: skipPhases{skipInitPhase: struct{}{}}},
|
||||||
fh: fsWithStateFileAndTfState,
|
fh: fsWithStateFileAndTfState,
|
||||||
},
|
},
|
||||||
"outdated K8s patch version": {
|
"outdated K8s patch version": {
|
||||||
|
@ -167,7 +159,7 @@ func TestUpgradeApply(t *testing.T) {
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
return semver.NewFromInt(v.Major(), v.Minor(), v.Patch()-1, "").String()
|
return semver.NewFromInt(v.Major(), v.Minor(), v.Patch()-1, "").String()
|
||||||
}(),
|
}(),
|
||||||
flags: applyFlags{yes: true},
|
flags: applyFlags{yes: true, skipPhases: skipPhases{skipInitPhase: struct{}{}}},
|
||||||
fh: fsWithStateFileAndTfState,
|
fh: fsWithStateFileAndTfState,
|
||||||
},
|
},
|
||||||
"outdated K8s version": {
|
"outdated K8s version": {
|
||||||
|
@ -177,7 +169,7 @@ func TestUpgradeApply(t *testing.T) {
|
||||||
helmUpgrader: stubApplier{},
|
helmUpgrader: stubApplier{},
|
||||||
terraformUpgrader: &stubTerraformUpgrader{},
|
terraformUpgrader: &stubTerraformUpgrader{},
|
||||||
customK8sVersion: "v1.20.0",
|
customK8sVersion: "v1.20.0",
|
||||||
flags: applyFlags{yes: true},
|
flags: applyFlags{yes: true, skipPhases: skipPhases{skipInitPhase: struct{}{}}},
|
||||||
wantErr: true,
|
wantErr: true,
|
||||||
fh: fsWithStateFileAndTfState,
|
fh: fsWithStateFileAndTfState,
|
||||||
},
|
},
|
||||||
|
@ -191,6 +183,7 @@ func TestUpgradeApply(t *testing.T) {
|
||||||
skipPhases: skipPhases{
|
skipPhases: skipPhases{
|
||||||
skipInfrastructurePhase: struct{}{}, skipHelmPhase: struct{}{},
|
skipInfrastructurePhase: struct{}{}, skipHelmPhase: struct{}{},
|
||||||
skipK8sPhase: struct{}{}, skipImagePhase: struct{}{},
|
skipK8sPhase: struct{}{}, skipImagePhase: struct{}{},
|
||||||
|
skipInitPhase: struct{}{},
|
||||||
},
|
},
|
||||||
yes: true,
|
yes: true,
|
||||||
},
|
},
|
||||||
|
@ -205,7 +198,7 @@ func TestUpgradeApply(t *testing.T) {
|
||||||
flags: applyFlags{
|
flags: applyFlags{
|
||||||
skipPhases: skipPhases{
|
skipPhases: skipPhases{
|
||||||
skipInfrastructurePhase: struct{}{}, skipHelmPhase: struct{}{},
|
skipInfrastructurePhase: struct{}{}, skipHelmPhase: struct{}{},
|
||||||
skipK8sPhase: struct{}{},
|
skipK8sPhase: struct{}{}, skipInitPhase: struct{}{},
|
||||||
},
|
},
|
||||||
yes: true,
|
yes: true,
|
||||||
},
|
},
|
||||||
|
@ -219,10 +212,13 @@ func TestUpgradeApply(t *testing.T) {
|
||||||
terraformUpgrader: &mockTerraformUpgrader{},
|
terraformUpgrader: &mockTerraformUpgrader{},
|
||||||
flags: applyFlags{
|
flags: applyFlags{
|
||||||
yes: true,
|
yes: true,
|
||||||
|
skipPhases: skipPhases{
|
||||||
|
skipInitPhase: struct{}{},
|
||||||
|
},
|
||||||
},
|
},
|
||||||
fh: func() file.Handler {
|
fh: func() file.Handler {
|
||||||
fh := file.NewHandler(afero.NewMemMapFs())
|
fh := file.NewHandler(afero.NewMemMapFs())
|
||||||
require.NoError(t, fh.WriteYAML(constants.StateFilename, defaultState))
|
require.NoError(t, fh.WriteYAML(constants.StateFilename, defaultAzureStateFile()))
|
||||||
return fh
|
return fh
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
@ -230,7 +226,7 @@ func TestUpgradeApply(t *testing.T) {
|
||||||
kubeUpgrader: &stubKubernetesUpgrader{currentConfig: &config.AzureTrustedLaunch{}},
|
kubeUpgrader: &stubKubernetesUpgrader{currentConfig: &config.AzureTrustedLaunch{}},
|
||||||
helmUpgrader: stubApplier{},
|
helmUpgrader: stubApplier{},
|
||||||
terraformUpgrader: &stubTerraformUpgrader{},
|
terraformUpgrader: &stubTerraformUpgrader{},
|
||||||
flags: applyFlags{yes: true},
|
flags: applyFlags{yes: true, skipPhases: skipPhases{skipInitPhase: struct{}{}}},
|
||||||
fh: fsWithStateFileAndTfState,
|
fh: fsWithStateFileAndTfState,
|
||||||
wantErr: true,
|
wantErr: true,
|
||||||
},
|
},
|
||||||
|
|
|
@ -155,6 +155,9 @@ func (c *verifyCmd) verify(cmd *cobra.Command, verifyClient verifyClient, factor
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("reading state file: %w", err)
|
return fmt.Errorf("reading state file: %w", err)
|
||||||
}
|
}
|
||||||
|
if err := stateFile.Validate(state.PostInit, conf.GetProvider()); err != nil {
|
||||||
|
return fmt.Errorf("validating state file: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
ownerID, clusterID, err := c.validateIDFlags(cmd, stateFile)
|
ownerID, clusterID, err := c.validateIDFlags(cmd, stateFile)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
@ -48,7 +48,7 @@ func TestVerify(t *testing.T) {
|
||||||
formatter *stubAttDocFormatter
|
formatter *stubAttDocFormatter
|
||||||
nodeEndpointFlag string
|
nodeEndpointFlag string
|
||||||
clusterIDFlag string
|
clusterIDFlag string
|
||||||
stateFile *state.State
|
stateFile func() *state.State
|
||||||
wantEndpoint string
|
wantEndpoint string
|
||||||
skipConfigCreation bool
|
skipConfigCreation bool
|
||||||
wantErr bool
|
wantErr bool
|
||||||
|
@ -58,7 +58,7 @@ func TestVerify(t *testing.T) {
|
||||||
nodeEndpointFlag: "192.0.2.1:1234",
|
nodeEndpointFlag: "192.0.2.1:1234",
|
||||||
clusterIDFlag: zeroBase64,
|
clusterIDFlag: zeroBase64,
|
||||||
protoClient: &stubVerifyClient{},
|
protoClient: &stubVerifyClient{},
|
||||||
stateFile: state.New(),
|
stateFile: defaultGCPStateFile,
|
||||||
wantEndpoint: "192.0.2.1:1234",
|
wantEndpoint: "192.0.2.1:1234",
|
||||||
formatter: &stubAttDocFormatter{},
|
formatter: &stubAttDocFormatter{},
|
||||||
},
|
},
|
||||||
|
@ -67,7 +67,7 @@ func TestVerify(t *testing.T) {
|
||||||
nodeEndpointFlag: "192.0.2.1:1234",
|
nodeEndpointFlag: "192.0.2.1:1234",
|
||||||
clusterIDFlag: zeroBase64,
|
clusterIDFlag: zeroBase64,
|
||||||
protoClient: &stubVerifyClient{},
|
protoClient: &stubVerifyClient{},
|
||||||
stateFile: state.New(),
|
stateFile: defaultAzureStateFile,
|
||||||
wantEndpoint: "192.0.2.1:1234",
|
wantEndpoint: "192.0.2.1:1234",
|
||||||
formatter: &stubAttDocFormatter{},
|
formatter: &stubAttDocFormatter{},
|
||||||
},
|
},
|
||||||
|
@ -76,7 +76,7 @@ func TestVerify(t *testing.T) {
|
||||||
nodeEndpointFlag: "192.0.2.1",
|
nodeEndpointFlag: "192.0.2.1",
|
||||||
clusterIDFlag: zeroBase64,
|
clusterIDFlag: zeroBase64,
|
||||||
protoClient: &stubVerifyClient{},
|
protoClient: &stubVerifyClient{},
|
||||||
stateFile: state.New(),
|
stateFile: defaultGCPStateFile,
|
||||||
wantEndpoint: "192.0.2.1:" + strconv.Itoa(constants.VerifyServiceNodePortGRPC),
|
wantEndpoint: "192.0.2.1:" + strconv.Itoa(constants.VerifyServiceNodePortGRPC),
|
||||||
formatter: &stubAttDocFormatter{},
|
formatter: &stubAttDocFormatter{},
|
||||||
},
|
},
|
||||||
|
@ -84,56 +84,78 @@ func TestVerify(t *testing.T) {
|
||||||
provider: cloudprovider.GCP,
|
provider: cloudprovider.GCP,
|
||||||
clusterIDFlag: zeroBase64,
|
clusterIDFlag: zeroBase64,
|
||||||
protoClient: &stubVerifyClient{},
|
protoClient: &stubVerifyClient{},
|
||||||
stateFile: state.New(),
|
stateFile: func() *state.State {
|
||||||
formatter: &stubAttDocFormatter{},
|
s := defaultGCPStateFile()
|
||||||
wantErr: true,
|
s.Infrastructure.ClusterEndpoint = ""
|
||||||
|
return s
|
||||||
|
},
|
||||||
|
formatter: &stubAttDocFormatter{},
|
||||||
|
wantErr: true,
|
||||||
},
|
},
|
||||||
"endpoint from state file": {
|
"endpoint from state file": {
|
||||||
provider: cloudprovider.GCP,
|
provider: cloudprovider.GCP,
|
||||||
clusterIDFlag: zeroBase64,
|
clusterIDFlag: zeroBase64,
|
||||||
protoClient: &stubVerifyClient{},
|
protoClient: &stubVerifyClient{},
|
||||||
stateFile: &state.State{Infrastructure: state.Infrastructure{ClusterEndpoint: "192.0.2.1"}},
|
stateFile: func() *state.State {
|
||||||
wantEndpoint: "192.0.2.1:" + strconv.Itoa(constants.VerifyServiceNodePortGRPC),
|
s := defaultGCPStateFile()
|
||||||
formatter: &stubAttDocFormatter{},
|
s.Infrastructure.ClusterEndpoint = "192.0.2.1"
|
||||||
|
return s
|
||||||
|
},
|
||||||
|
wantEndpoint: "192.0.2.1:" + strconv.Itoa(constants.VerifyServiceNodePortGRPC),
|
||||||
|
formatter: &stubAttDocFormatter{},
|
||||||
},
|
},
|
||||||
"override endpoint from details file": {
|
"override endpoint from details file": {
|
||||||
provider: cloudprovider.GCP,
|
provider: cloudprovider.GCP,
|
||||||
nodeEndpointFlag: "192.0.2.2:1234",
|
nodeEndpointFlag: "192.0.2.2:1234",
|
||||||
clusterIDFlag: zeroBase64,
|
clusterIDFlag: zeroBase64,
|
||||||
protoClient: &stubVerifyClient{},
|
protoClient: &stubVerifyClient{},
|
||||||
stateFile: &state.State{Infrastructure: state.Infrastructure{ClusterEndpoint: "192.0.2.1"}},
|
stateFile: func() *state.State {
|
||||||
wantEndpoint: "192.0.2.2:1234",
|
s := defaultGCPStateFile()
|
||||||
formatter: &stubAttDocFormatter{},
|
s.Infrastructure.ClusterEndpoint = "192.0.2.1"
|
||||||
|
return s
|
||||||
|
},
|
||||||
|
wantEndpoint: "192.0.2.2:1234",
|
||||||
|
formatter: &stubAttDocFormatter{},
|
||||||
},
|
},
|
||||||
"invalid endpoint": {
|
"invalid endpoint": {
|
||||||
provider: cloudprovider.GCP,
|
provider: cloudprovider.GCP,
|
||||||
nodeEndpointFlag: ":::::",
|
nodeEndpointFlag: ":::::",
|
||||||
clusterIDFlag: zeroBase64,
|
clusterIDFlag: zeroBase64,
|
||||||
protoClient: &stubVerifyClient{},
|
protoClient: &stubVerifyClient{},
|
||||||
stateFile: state.New(),
|
stateFile: defaultGCPStateFile,
|
||||||
formatter: &stubAttDocFormatter{},
|
formatter: &stubAttDocFormatter{},
|
||||||
wantErr: true,
|
wantErr: true,
|
||||||
},
|
},
|
||||||
"neither owner id nor cluster id set": {
|
"neither owner id nor cluster id set": {
|
||||||
provider: cloudprovider.GCP,
|
provider: cloudprovider.GCP,
|
||||||
nodeEndpointFlag: "192.0.2.1:1234",
|
nodeEndpointFlag: "192.0.2.1:1234",
|
||||||
stateFile: state.New(),
|
stateFile: func() *state.State {
|
||||||
formatter: &stubAttDocFormatter{},
|
s := defaultGCPStateFile()
|
||||||
wantErr: true,
|
s.ClusterValues.OwnerID = ""
|
||||||
|
s.ClusterValues.ClusterID = ""
|
||||||
|
return s
|
||||||
|
},
|
||||||
|
formatter: &stubAttDocFormatter{},
|
||||||
|
protoClient: &stubVerifyClient{},
|
||||||
|
wantErr: true,
|
||||||
},
|
},
|
||||||
"use owner id from state file": {
|
"use owner id from state file": {
|
||||||
provider: cloudprovider.GCP,
|
provider: cloudprovider.GCP,
|
||||||
nodeEndpointFlag: "192.0.2.1:1234",
|
nodeEndpointFlag: "192.0.2.1:1234",
|
||||||
protoClient: &stubVerifyClient{},
|
protoClient: &stubVerifyClient{},
|
||||||
stateFile: &state.State{ClusterValues: state.ClusterValues{OwnerID: zeroBase64}},
|
stateFile: func() *state.State {
|
||||||
wantEndpoint: "192.0.2.1:1234",
|
s := defaultGCPStateFile()
|
||||||
formatter: &stubAttDocFormatter{},
|
s.ClusterValues.OwnerID = zeroBase64
|
||||||
|
return s
|
||||||
|
},
|
||||||
|
wantEndpoint: "192.0.2.1:1234",
|
||||||
|
formatter: &stubAttDocFormatter{},
|
||||||
},
|
},
|
||||||
"config file not existing": {
|
"config file not existing": {
|
||||||
provider: cloudprovider.GCP,
|
provider: cloudprovider.GCP,
|
||||||
clusterIDFlag: zeroBase64,
|
clusterIDFlag: zeroBase64,
|
||||||
nodeEndpointFlag: "192.0.2.1:1234",
|
nodeEndpointFlag: "192.0.2.1:1234",
|
||||||
stateFile: state.New(),
|
stateFile: defaultGCPStateFile,
|
||||||
formatter: &stubAttDocFormatter{},
|
formatter: &stubAttDocFormatter{},
|
||||||
skipConfigCreation: true,
|
skipConfigCreation: true,
|
||||||
wantErr: true,
|
wantErr: true,
|
||||||
|
@ -143,7 +165,7 @@ func TestVerify(t *testing.T) {
|
||||||
nodeEndpointFlag: "192.0.2.1:1234",
|
nodeEndpointFlag: "192.0.2.1:1234",
|
||||||
clusterIDFlag: zeroBase64,
|
clusterIDFlag: zeroBase64,
|
||||||
protoClient: &stubVerifyClient{verifyErr: rpcStatus.Error(codes.Internal, "failed")},
|
protoClient: &stubVerifyClient{verifyErr: rpcStatus.Error(codes.Internal, "failed")},
|
||||||
stateFile: state.New(),
|
stateFile: defaultAzureStateFile,
|
||||||
formatter: &stubAttDocFormatter{},
|
formatter: &stubAttDocFormatter{},
|
||||||
wantErr: true,
|
wantErr: true,
|
||||||
},
|
},
|
||||||
|
@ -152,7 +174,7 @@ func TestVerify(t *testing.T) {
|
||||||
nodeEndpointFlag: "192.0.2.1:1234",
|
nodeEndpointFlag: "192.0.2.1:1234",
|
||||||
clusterIDFlag: zeroBase64,
|
clusterIDFlag: zeroBase64,
|
||||||
protoClient: &stubVerifyClient{verifyErr: someErr},
|
protoClient: &stubVerifyClient{verifyErr: someErr},
|
||||||
stateFile: state.New(),
|
stateFile: defaultAzureStateFile,
|
||||||
formatter: &stubAttDocFormatter{},
|
formatter: &stubAttDocFormatter{},
|
||||||
wantErr: true,
|
wantErr: true,
|
||||||
},
|
},
|
||||||
|
@ -161,7 +183,7 @@ func TestVerify(t *testing.T) {
|
||||||
nodeEndpointFlag: "192.0.2.1:1234",
|
nodeEndpointFlag: "192.0.2.1:1234",
|
||||||
clusterIDFlag: zeroBase64,
|
clusterIDFlag: zeroBase64,
|
||||||
protoClient: &stubVerifyClient{},
|
protoClient: &stubVerifyClient{},
|
||||||
stateFile: state.New(),
|
stateFile: defaultAzureStateFile,
|
||||||
wantEndpoint: "192.0.2.1:1234",
|
wantEndpoint: "192.0.2.1:1234",
|
||||||
formatter: &stubAttDocFormatter{formatErr: someErr},
|
formatter: &stubAttDocFormatter{formatErr: someErr},
|
||||||
wantErr: true,
|
wantErr: true,
|
||||||
|
@ -182,7 +204,7 @@ func TestVerify(t *testing.T) {
|
||||||
cfg := defaultConfigWithExpectedMeasurements(t, config.Default(), tc.provider)
|
cfg := defaultConfigWithExpectedMeasurements(t, config.Default(), tc.provider)
|
||||||
require.NoError(fileHandler.WriteYAML(constants.ConfigFilename, cfg))
|
require.NoError(fileHandler.WriteYAML(constants.ConfigFilename, cfg))
|
||||||
}
|
}
|
||||||
require.NoError(tc.stateFile.WriteToFile(fileHandler, constants.StateFilename))
|
require.NoError(tc.stateFile().WriteToFile(fileHandler, constants.StateFilename))
|
||||||
|
|
||||||
v := &verifyCmd{
|
v := &verifyCmd{
|
||||||
fileHandler: fileHandler,
|
fileHandler: fileHandler,
|
||||||
|
|
|
@ -10,7 +10,9 @@ go_library(
|
||||||
importpath = "github.com/edgelesssys/constellation/v2/cli/internal/state",
|
importpath = "github.com/edgelesssys/constellation/v2/cli/internal/state",
|
||||||
visibility = ["//cli:__subpackages__"],
|
visibility = ["//cli:__subpackages__"],
|
||||||
deps = [
|
deps = [
|
||||||
|
"//internal/cloud/cloudprovider",
|
||||||
"//internal/file",
|
"//internal/file",
|
||||||
|
"//internal/validation",
|
||||||
"@cat_dario_mergo//:mergo",
|
"@cat_dario_mergo//:mergo",
|
||||||
"@com_github_siderolabs_talos_pkg_machinery//config/encoder",
|
"@com_github_siderolabs_talos_pkg_machinery//config/encoder",
|
||||||
],
|
],
|
||||||
|
@ -18,9 +20,13 @@ go_library(
|
||||||
|
|
||||||
go_test(
|
go_test(
|
||||||
name = "state_test",
|
name = "state_test",
|
||||||
srcs = ["state_test.go"],
|
srcs = [
|
||||||
|
"state_test.go",
|
||||||
|
"validation_test.go",
|
||||||
|
],
|
||||||
embed = [":state"],
|
embed = [":state"],
|
||||||
deps = [
|
deps = [
|
||||||
|
"//internal/cloud/cloudprovider",
|
||||||
"//internal/constants",
|
"//internal/constants",
|
||||||
"//internal/file",
|
"//internal/file",
|
||||||
"@com_github_siderolabs_talos_pkg_machinery//config/encoder",
|
"@com_github_siderolabs_talos_pkg_machinery//config/encoder",
|
||||||
|
|
|
@ -19,7 +19,9 @@ import (
|
||||||
"os"
|
"os"
|
||||||
|
|
||||||
"dario.cat/mergo"
|
"dario.cat/mergo"
|
||||||
|
"github.com/edgelesssys/constellation/v2/internal/cloud/cloudprovider"
|
||||||
"github.com/edgelesssys/constellation/v2/internal/file"
|
"github.com/edgelesssys/constellation/v2/internal/file"
|
||||||
|
"github.com/edgelesssys/constellation/v2/internal/validation"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
@ -27,25 +29,44 @@ const (
|
||||||
Version1 = "v1"
|
Version1 = "v1"
|
||||||
)
|
)
|
||||||
|
|
||||||
// ReadFromFile reads the state file at the given path and returns the state.
|
const (
|
||||||
|
// PreCreate are the constraints that should be enforced when the state file
|
||||||
|
// is validated before cloud infrastructure is created.
|
||||||
|
PreCreate ConstraintSet = iota
|
||||||
|
// PreInit are the constraints that should be enforced when the state file
|
||||||
|
// is validated before the first Constellation node is initialized.
|
||||||
|
PreInit
|
||||||
|
// PostInit are the constraints that should be enforced when the state file
|
||||||
|
// is validated after the cluster was initialized.
|
||||||
|
PostInit
|
||||||
|
)
|
||||||
|
|
||||||
|
// ConstraintSet defines which constraints the state file
|
||||||
|
// should be validated against.
|
||||||
|
type ConstraintSet int
|
||||||
|
|
||||||
|
// ReadFromFile reads the state file at the given path and validates it.
|
||||||
|
// If the state file is valid, the state is returned. Otherwise, an error
|
||||||
|
// describing why the validation failed is returned.
|
||||||
func ReadFromFile(fileHandler file.Handler, path string) (*State, error) {
|
func ReadFromFile(fileHandler file.Handler, path string) (*State, error) {
|
||||||
state := &State{}
|
state := &State{}
|
||||||
if err := fileHandler.ReadYAML(path, &state); err != nil {
|
if err := fileHandler.ReadYAML(path, &state); err != nil {
|
||||||
return nil, fmt.Errorf("reading state file: %w", err)
|
return nil, fmt.Errorf("reading state file: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return state, nil
|
return state, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// CreateOrRead reads the state file at the given path, if it exists, and returns the state.
|
// CreateOrRead reads the state file at the given path, if it exists, and returns the state.
|
||||||
// If the file does not exist, a new state is created and written to disk.
|
// If the file does not exist, a new state is created and written to disk.
|
||||||
func CreateOrRead(fileHandler file.Handler, path string) (*State, error) {
|
func CreateOrRead(fileHandler file.Handler, path string) (*State, error) {
|
||||||
state := &State{}
|
state, err := ReadFromFile(fileHandler, path)
|
||||||
if err := fileHandler.ReadYAML(path, &state); err != nil {
|
if err != nil {
|
||||||
if !errors.Is(err, os.ErrNotExist) {
|
if !errors.Is(err, os.ErrNotExist) {
|
||||||
return nil, fmt.Errorf("reading state file: %w", err)
|
return nil, fmt.Errorf("reading state file: %w", err)
|
||||||
}
|
}
|
||||||
state = New()
|
newState := New()
|
||||||
return state, state.WriteToFile(fileHandler, path)
|
return newState, newState.WriteToFile(fileHandler, path)
|
||||||
}
|
}
|
||||||
return state, nil
|
return state, nil
|
||||||
}
|
}
|
||||||
|
@ -186,6 +207,349 @@ func (s *State) Merge(other *State) (*State, error) {
|
||||||
return s, nil
|
return s, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
|
Validate validates the state against the given constraint set and CSP, which can be one of
|
||||||
|
- PreCreate, which is the constraint set that should be enforced before "constellation create" is run.
|
||||||
|
- PreInit, which is the constraint set that should be enforced before "constellation apply" is run.
|
||||||
|
- PostInit, which is the constraint set that should be enforced after "constellation apply" is run.
|
||||||
|
*/
|
||||||
|
func (s *State) Validate(constraintSet ConstraintSet, csp cloudprovider.Provider) error {
|
||||||
|
v := validation.NewValidator()
|
||||||
|
|
||||||
|
switch constraintSet {
|
||||||
|
case PreCreate:
|
||||||
|
return v.Validate(s, validation.ValidateOptions{
|
||||||
|
OverrideConstraints: s.preCreateConstraints,
|
||||||
|
})
|
||||||
|
case PreInit:
|
||||||
|
return v.Validate(s, validation.ValidateOptions{
|
||||||
|
OverrideConstraints: s.preInitConstraints,
|
||||||
|
})
|
||||||
|
case PostInit:
|
||||||
|
return v.Validate(s, validation.ValidateOptions{
|
||||||
|
OverrideConstraints: s.postInitConstraints(csp),
|
||||||
|
})
|
||||||
|
default:
|
||||||
|
return errors.New("unknown constraint set")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// preCreateConstraints are the constraints on the state that should be enforced
|
||||||
|
// before a Constellation cluster is created.
|
||||||
|
//
|
||||||
|
// The constraints check if the state file version is valid,
|
||||||
|
// and if all fields are empty, which is a requirement pre-create.
|
||||||
|
func (s *State) preCreateConstraints() []*validation.Constraint {
|
||||||
|
return []*validation.Constraint{
|
||||||
|
// state version needs to be accepted by the parsing CLI.
|
||||||
|
validation.OneOf(s.Version, []string{Version1}).
|
||||||
|
WithFieldTrace(s, &s.Version),
|
||||||
|
// Infrastructure must be empty.
|
||||||
|
// As the infrastructure struct contains slices, we cannot use the
|
||||||
|
// Empty constraint on the entire struct. Instead, we need to check
|
||||||
|
// each field individually.
|
||||||
|
validation.Empty(s.Infrastructure.UID).
|
||||||
|
WithFieldTrace(s, &s.Infrastructure.UID),
|
||||||
|
validation.Empty(s.Infrastructure.ClusterEndpoint).
|
||||||
|
WithFieldTrace(s, &s.Infrastructure.ClusterEndpoint),
|
||||||
|
validation.Empty(s.Infrastructure.InClusterEndpoint).
|
||||||
|
WithFieldTrace(s, &s.Infrastructure.InClusterEndpoint),
|
||||||
|
validation.Empty(s.Infrastructure.Name).
|
||||||
|
WithFieldTrace(s, &s.Infrastructure.Name),
|
||||||
|
validation.Empty(s.Infrastructure.IPCidrNode).
|
||||||
|
WithFieldTrace(s, &s.Infrastructure.IPCidrNode),
|
||||||
|
validation.EmptySlice(s.Infrastructure.APIServerCertSANs).
|
||||||
|
WithFieldTrace(s, &s.Infrastructure.APIServerCertSANs),
|
||||||
|
validation.EmptySlice(s.Infrastructure.InitSecret).
|
||||||
|
WithFieldTrace(s, &s.Infrastructure.InitSecret),
|
||||||
|
// ClusterValues must be empty.
|
||||||
|
// As the clusterValues struct contains slices, we cannot use the
|
||||||
|
// Empty constraint on the entire struct. Instead, we need to check
|
||||||
|
// each field individually.
|
||||||
|
validation.Empty(s.ClusterValues.ClusterID).
|
||||||
|
WithFieldTrace(s, &s.ClusterValues.ClusterID),
|
||||||
|
validation.Empty(s.ClusterValues.OwnerID).
|
||||||
|
WithFieldTrace(s, &s.ClusterValues.OwnerID),
|
||||||
|
validation.EmptySlice(s.ClusterValues.MeasurementSalt).
|
||||||
|
WithFieldTrace(s, &s.ClusterValues.MeasurementSalt),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// preInitConstraints are the constraints on the state that should be enforced
|
||||||
|
// *before* a Constellation cluster is initialized. (i.e. before "constellation apply" is run.)
|
||||||
|
//
|
||||||
|
// The constraints check if the infrastructure state is valid, and if the cluster values
|
||||||
|
// are empty, which is required for the cluster to initialize correctly.
|
||||||
|
func (s *State) preInitConstraints() []*validation.Constraint {
|
||||||
|
return []*validation.Constraint{
|
||||||
|
// state version needs to be accepted by the parsing CLI.
|
||||||
|
validation.OneOf(s.Version, []string{Version1}).
|
||||||
|
WithFieldTrace(s, &s.Version),
|
||||||
|
// infrastructure must be valid.
|
||||||
|
// out-of-cluster endpoint needs to be a valid DNS name or IP address.
|
||||||
|
validation.Or(
|
||||||
|
validation.DNSName(s.Infrastructure.ClusterEndpoint).
|
||||||
|
WithFieldTrace(s, &s.Infrastructure.ClusterEndpoint),
|
||||||
|
validation.IPAddress(s.Infrastructure.ClusterEndpoint).
|
||||||
|
WithFieldTrace(s, &s.Infrastructure.ClusterEndpoint),
|
||||||
|
),
|
||||||
|
// in-cluster endpoint needs to be a valid DNS name or IP address.
|
||||||
|
validation.Or(
|
||||||
|
validation.DNSName(s.Infrastructure.InClusterEndpoint).
|
||||||
|
WithFieldTrace(s, &s.Infrastructure.InClusterEndpoint),
|
||||||
|
validation.IPAddress(s.Infrastructure.InClusterEndpoint).
|
||||||
|
WithFieldTrace(s, &s.Infrastructure.InClusterEndpoint),
|
||||||
|
),
|
||||||
|
// Node IP Cidr needs to be a valid CIDR range.
|
||||||
|
validation.CIDR(s.Infrastructure.IPCidrNode).
|
||||||
|
WithFieldTrace(s, &s.Infrastructure.IPCidrNode),
|
||||||
|
// UID needs to be filled.
|
||||||
|
validation.NotEmpty(s.Infrastructure.UID).
|
||||||
|
WithFieldTrace(s, &s.Infrastructure.UID),
|
||||||
|
// Name needs to be filled.
|
||||||
|
validation.NotEmpty(s.Infrastructure.Name).
|
||||||
|
WithFieldTrace(s, &s.Infrastructure.Name),
|
||||||
|
// GCP values need to be nil, empty, or valid.
|
||||||
|
validation.Or(
|
||||||
|
validation.Or(
|
||||||
|
// nil.
|
||||||
|
validation.Equal(s.Infrastructure.GCP, nil).
|
||||||
|
WithFieldTrace(s, &s.Infrastructure.GCP),
|
||||||
|
// empty.
|
||||||
|
validation.IfNotNil(
|
||||||
|
s.Infrastructure.GCP,
|
||||||
|
func() *validation.Constraint {
|
||||||
|
return validation.Empty(*s.Infrastructure.GCP).
|
||||||
|
WithFieldTrace(s, &s.Infrastructure.GCP)
|
||||||
|
},
|
||||||
|
),
|
||||||
|
),
|
||||||
|
// valid.
|
||||||
|
validation.IfNotNil(
|
||||||
|
s.Infrastructure.GCP,
|
||||||
|
func() *validation.Constraint {
|
||||||
|
return validation.And(
|
||||||
|
validation.EvaluateAll,
|
||||||
|
// ProjectID needs to be filled.
|
||||||
|
validation.NotEmpty(s.Infrastructure.GCP.ProjectID).
|
||||||
|
WithFieldTrace(s, &s.Infrastructure.GCP.ProjectID),
|
||||||
|
// Pod IP Cidr needs to be a valid CIDR range.
|
||||||
|
validation.CIDR(s.Infrastructure.GCP.IPCidrPod).
|
||||||
|
WithFieldTrace(s, &s.Infrastructure.GCP.IPCidrPod),
|
||||||
|
)
|
||||||
|
},
|
||||||
|
),
|
||||||
|
),
|
||||||
|
// Azure values need to be nil, empty, or valid.
|
||||||
|
validation.Or(
|
||||||
|
validation.Or(
|
||||||
|
// nil.
|
||||||
|
validation.Equal(s.Infrastructure.Azure, nil).
|
||||||
|
WithFieldTrace(s, &s.Infrastructure.Azure),
|
||||||
|
// empty.
|
||||||
|
validation.IfNotNil(
|
||||||
|
s.Infrastructure.Azure,
|
||||||
|
func() *validation.Constraint {
|
||||||
|
return validation.And(
|
||||||
|
validation.EvaluateAll,
|
||||||
|
validation.Empty(s.Infrastructure.Azure.ResourceGroup).
|
||||||
|
WithFieldTrace(s, &s.Infrastructure.Azure.ResourceGroup),
|
||||||
|
validation.Empty(s.Infrastructure.Azure.SubscriptionID).
|
||||||
|
WithFieldTrace(s, &s.Infrastructure.Azure.SubscriptionID),
|
||||||
|
validation.Empty(s.Infrastructure.Azure.NetworkSecurityGroupName).
|
||||||
|
WithFieldTrace(s, &s.Infrastructure.Azure.NetworkSecurityGroupName),
|
||||||
|
validation.Empty(s.Infrastructure.Azure.LoadBalancerName).
|
||||||
|
WithFieldTrace(s, &s.Infrastructure.Azure.LoadBalancerName),
|
||||||
|
validation.Empty(s.Infrastructure.Azure.UserAssignedIdentity).
|
||||||
|
WithFieldTrace(s, &s.Infrastructure.Azure.UserAssignedIdentity),
|
||||||
|
validation.Empty(s.Infrastructure.Azure.AttestationURL).
|
||||||
|
WithFieldTrace(s, &s.Infrastructure.Azure.AttestationURL),
|
||||||
|
)
|
||||||
|
},
|
||||||
|
),
|
||||||
|
),
|
||||||
|
// valid.
|
||||||
|
validation.IfNotNil(
|
||||||
|
s.Infrastructure.Azure,
|
||||||
|
func() *validation.Constraint {
|
||||||
|
return validation.And(
|
||||||
|
validation.EvaluateAll,
|
||||||
|
validation.NotEmpty(s.Infrastructure.Azure.ResourceGroup).
|
||||||
|
WithFieldTrace(s, &s.Infrastructure.Azure.ResourceGroup),
|
||||||
|
validation.NotEmpty(s.Infrastructure.Azure.SubscriptionID).
|
||||||
|
WithFieldTrace(s, &s.Infrastructure.Azure.SubscriptionID),
|
||||||
|
validation.NotEmpty(s.Infrastructure.Azure.NetworkSecurityGroupName).
|
||||||
|
WithFieldTrace(s, &s.Infrastructure.Azure.NetworkSecurityGroupName),
|
||||||
|
validation.NotEmpty(s.Infrastructure.Azure.LoadBalancerName).
|
||||||
|
WithFieldTrace(s, &s.Infrastructure.Azure.LoadBalancerName),
|
||||||
|
validation.NotEmpty(s.Infrastructure.Azure.UserAssignedIdentity).
|
||||||
|
WithFieldTrace(s, &s.Infrastructure.Azure.UserAssignedIdentity),
|
||||||
|
validation.NotEmpty(s.Infrastructure.Azure.AttestationURL).
|
||||||
|
WithFieldTrace(s, &s.Infrastructure.Azure.AttestationURL),
|
||||||
|
)
|
||||||
|
},
|
||||||
|
),
|
||||||
|
),
|
||||||
|
// ClusterValues must be empty.
|
||||||
|
// As the clusterValues struct contains slices, we cannot use the
|
||||||
|
// Empty constraint on the entire struct. Instead, we need to check
|
||||||
|
// each field individually.
|
||||||
|
validation.Empty(s.ClusterValues.ClusterID).
|
||||||
|
WithFieldTrace(s, &s.ClusterValues.ClusterID),
|
||||||
|
validation.Empty(s.ClusterValues.OwnerID).
|
||||||
|
WithFieldTrace(s, &s.ClusterValues.OwnerID),
|
||||||
|
validation.EmptySlice(s.ClusterValues.MeasurementSalt).
|
||||||
|
WithFieldTrace(s, &s.ClusterValues.MeasurementSalt),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// postInitConstraints are the constraints on the state that should be enforced
|
||||||
|
// *after* a Constellation cluster is initialized. (i.e. before "constellation apply" is run.)
|
||||||
|
//
|
||||||
|
// The constraints check if the infrastructure state and cluster state
|
||||||
|
// is valid, so that the cluster can be used correctly.
|
||||||
|
func (s *State) postInitConstraints(csp cloudprovider.Provider) func() []*validation.Constraint {
|
||||||
|
return func() []*validation.Constraint {
|
||||||
|
constraints := []*validation.Constraint{
|
||||||
|
// state version needs to be accepted by the parsing CLI.
|
||||||
|
validation.OneOf(s.Version, []string{Version1}).
|
||||||
|
WithFieldTrace(s, &s.Version),
|
||||||
|
// infrastructure must be valid.
|
||||||
|
// out-of-cluster endpoint needs to be a valid DNS name or IP address.
|
||||||
|
validation.Or(
|
||||||
|
validation.DNSName(s.Infrastructure.ClusterEndpoint).
|
||||||
|
WithFieldTrace(s, &s.Infrastructure.ClusterEndpoint),
|
||||||
|
validation.IPAddress(s.Infrastructure.ClusterEndpoint).
|
||||||
|
WithFieldTrace(s, &s.Infrastructure.ClusterEndpoint),
|
||||||
|
),
|
||||||
|
// in-cluster endpoint needs to be a valid DNS name or IP address.
|
||||||
|
validation.Or(
|
||||||
|
validation.DNSName(s.Infrastructure.InClusterEndpoint).
|
||||||
|
WithFieldTrace(s, &s.Infrastructure.InClusterEndpoint),
|
||||||
|
validation.IPAddress(s.Infrastructure.InClusterEndpoint).
|
||||||
|
WithFieldTrace(s, &s.Infrastructure.InClusterEndpoint),
|
||||||
|
),
|
||||||
|
// Node IP Cidr needs to be a valid CIDR range.
|
||||||
|
validation.CIDR(s.Infrastructure.IPCidrNode).
|
||||||
|
WithFieldTrace(s, &s.Infrastructure.IPCidrNode),
|
||||||
|
// UID needs to be filled.
|
||||||
|
validation.NotEmpty(s.Infrastructure.UID).
|
||||||
|
WithFieldTrace(s, &s.Infrastructure.UID),
|
||||||
|
// Name needs to be filled.
|
||||||
|
validation.NotEmpty(s.Infrastructure.Name).
|
||||||
|
WithFieldTrace(s, &s.Infrastructure.Name),
|
||||||
|
// ClusterValues need to be valid.
|
||||||
|
// ClusterID needs to be filled.
|
||||||
|
validation.NotEmpty(s.ClusterValues.ClusterID).
|
||||||
|
WithFieldTrace(s, &s.ClusterValues.ClusterID),
|
||||||
|
// OwnerID needs to be filled.
|
||||||
|
validation.NotEmpty(s.ClusterValues.OwnerID).
|
||||||
|
WithFieldTrace(s, &s.ClusterValues.OwnerID),
|
||||||
|
// MeasurementSalt needs to be filled.
|
||||||
|
validation.NotEmptySlice(s.ClusterValues.MeasurementSalt).
|
||||||
|
WithFieldTrace(s, &s.ClusterValues.MeasurementSalt),
|
||||||
|
}
|
||||||
|
|
||||||
|
switch csp {
|
||||||
|
case cloudprovider.Azure:
|
||||||
|
constraints = append(constraints,
|
||||||
|
// GCP values need to be nil or empty.
|
||||||
|
validation.Or(
|
||||||
|
validation.Equal(s.Infrastructure.GCP, nil).
|
||||||
|
WithFieldTrace(s, &s.Infrastructure.GCP),
|
||||||
|
validation.IfNotNil(
|
||||||
|
s.Infrastructure.GCP,
|
||||||
|
func() *validation.Constraint {
|
||||||
|
return validation.Empty(s.Infrastructure.GCP).
|
||||||
|
WithFieldTrace(s, &s.Infrastructure.GCP)
|
||||||
|
},
|
||||||
|
)),
|
||||||
|
// Azure values need to be valid.
|
||||||
|
validation.IfNotNil(
|
||||||
|
s.Infrastructure.Azure,
|
||||||
|
func() *validation.Constraint {
|
||||||
|
return validation.And(
|
||||||
|
validation.EvaluateAll,
|
||||||
|
validation.NotEmpty(s.Infrastructure.Azure.ResourceGroup).
|
||||||
|
WithFieldTrace(s, &s.Infrastructure.Azure.ResourceGroup),
|
||||||
|
validation.NotEmpty(s.Infrastructure.Azure.SubscriptionID).
|
||||||
|
WithFieldTrace(s, &s.Infrastructure.Azure.SubscriptionID),
|
||||||
|
validation.NotEmpty(s.Infrastructure.Azure.NetworkSecurityGroupName).
|
||||||
|
WithFieldTrace(s, &s.Infrastructure.Azure.NetworkSecurityGroupName),
|
||||||
|
validation.NotEmpty(s.Infrastructure.Azure.LoadBalancerName).
|
||||||
|
WithFieldTrace(s, &s.Infrastructure.Azure.LoadBalancerName),
|
||||||
|
validation.NotEmpty(s.Infrastructure.Azure.UserAssignedIdentity).
|
||||||
|
WithFieldTrace(s, &s.Infrastructure.Azure.UserAssignedIdentity),
|
||||||
|
validation.NotEmpty(s.Infrastructure.Azure.AttestationURL).
|
||||||
|
WithFieldTrace(s, &s.Infrastructure.Azure.AttestationURL),
|
||||||
|
)
|
||||||
|
},
|
||||||
|
),
|
||||||
|
)
|
||||||
|
case cloudprovider.GCP:
|
||||||
|
constraints = append(constraints,
|
||||||
|
// Azure values need to be nil or empty.
|
||||||
|
validation.Or(
|
||||||
|
validation.Equal(s.Infrastructure.Azure, nil).
|
||||||
|
WithFieldTrace(s, &s.Infrastructure.Azure),
|
||||||
|
validation.IfNotNil(
|
||||||
|
s.Infrastructure.Azure,
|
||||||
|
func() *validation.Constraint {
|
||||||
|
return validation.Empty(s.Infrastructure.Azure).
|
||||||
|
WithFieldTrace(s, &s.Infrastructure.Azure)
|
||||||
|
},
|
||||||
|
)),
|
||||||
|
// GCP values need to be valid.
|
||||||
|
validation.IfNotNil(
|
||||||
|
s.Infrastructure.GCP,
|
||||||
|
func() *validation.Constraint {
|
||||||
|
return validation.And(
|
||||||
|
validation.EvaluateAll,
|
||||||
|
// ProjectID needs to be filled.
|
||||||
|
validation.NotEmpty(s.Infrastructure.GCP.ProjectID).
|
||||||
|
WithFieldTrace(s, &s.Infrastructure.GCP.ProjectID),
|
||||||
|
// Pod IP Cidr needs to be a valid CIDR range.
|
||||||
|
validation.CIDR(s.Infrastructure.GCP.IPCidrPod).
|
||||||
|
WithFieldTrace(s, &s.Infrastructure.GCP.IPCidrPod),
|
||||||
|
)
|
||||||
|
},
|
||||||
|
),
|
||||||
|
)
|
||||||
|
default:
|
||||||
|
constraints = append(constraints,
|
||||||
|
// GCP values need to be nil or empty.
|
||||||
|
validation.Or(
|
||||||
|
validation.Equal(s.Infrastructure.GCP, nil).
|
||||||
|
WithFieldTrace(s, &s.Infrastructure.GCP),
|
||||||
|
validation.IfNotNil(
|
||||||
|
s.Infrastructure.GCP,
|
||||||
|
func() *validation.Constraint {
|
||||||
|
return validation.Empty(s.Infrastructure.GCP).
|
||||||
|
WithFieldTrace(s, &s.Infrastructure.GCP)
|
||||||
|
},
|
||||||
|
)),
|
||||||
|
// Azure values need to be nil or empty.
|
||||||
|
validation.Or(
|
||||||
|
validation.Equal(s.Infrastructure.Azure, nil).
|
||||||
|
WithFieldTrace(s, &s.Infrastructure.Azure),
|
||||||
|
validation.IfNotNil(
|
||||||
|
s.Infrastructure.Azure,
|
||||||
|
func() *validation.Constraint {
|
||||||
|
return validation.Empty(s.Infrastructure.Azure).
|
||||||
|
WithFieldTrace(s, &s.Infrastructure.Azure)
|
||||||
|
},
|
||||||
|
)),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
return constraints
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Constraints is a no-op implementation to fulfill the "Validatable" interface.
|
||||||
|
func (s *State) Constraints() []*validation.Constraint {
|
||||||
|
return []*validation.Constraint{}
|
||||||
|
}
|
||||||
|
|
||||||
// HexBytes is a byte slice that is marshalled to and from a hex string.
|
// HexBytes is a byte slice that is marshalled to and from a hex string.
|
||||||
type HexBytes []byte
|
type HexBytes []byte
|
||||||
|
|
||||||
|
|
|
@ -18,18 +18,21 @@ import (
|
||||||
"gopkg.in/yaml.v3"
|
"gopkg.in/yaml.v3"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// defaultState returns a valid default state for testing.
|
||||||
func defaultState() *State {
|
func defaultState() *State {
|
||||||
return &State{
|
return &State{
|
||||||
Version: "v1",
|
Version: "v1",
|
||||||
Infrastructure: Infrastructure{
|
Infrastructure: Infrastructure{
|
||||||
UID: "123",
|
UID: "123",
|
||||||
ClusterEndpoint: "test-cluster-endpoint",
|
Name: "test-cluster",
|
||||||
InitSecret: []byte{0x41},
|
ClusterEndpoint: "0.0.0.0",
|
||||||
|
InClusterEndpoint: "0.0.0.0",
|
||||||
|
InitSecret: []byte{0x41},
|
||||||
APIServerCertSANs: []string{
|
APIServerCertSANs: []string{
|
||||||
"api-server-cert-san-test",
|
"127.0.0.1",
|
||||||
"api-server-cert-san-test-2",
|
"www.example.com",
|
||||||
},
|
},
|
||||||
IPCidrNode: "test-cidr-node",
|
IPCidrNode: "0.0.0.0/24",
|
||||||
Azure: &Azure{
|
Azure: &Azure{
|
||||||
ResourceGroup: "test-rg",
|
ResourceGroup: "test-rg",
|
||||||
SubscriptionID: "test-sub",
|
SubscriptionID: "test-sub",
|
||||||
|
@ -40,7 +43,7 @@ func defaultState() *State {
|
||||||
},
|
},
|
||||||
GCP: &GCP{
|
GCP: &GCP{
|
||||||
ProjectID: "test-project",
|
ProjectID: "test-project",
|
||||||
IPCidrPod: "test-cidr-pod",
|
IPCidrPod: "0.0.0.0/24",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
ClusterValues: ClusterValues{
|
ClusterValues: ClusterValues{
|
||||||
|
@ -51,6 +54,18 @@ func defaultState() *State {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func defaultAzureState() *State {
|
||||||
|
s := defaultState()
|
||||||
|
s.Infrastructure.GCP = nil
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
func defaultGCPState() *State {
|
||||||
|
s := defaultState()
|
||||||
|
s.Infrastructure.Azure = nil
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
func TestWriteToFile(t *testing.T) {
|
func TestWriteToFile(t *testing.T) {
|
||||||
testCases := map[string]struct {
|
testCases := map[string]struct {
|
||||||
state *State
|
state *State
|
||||||
|
|
379
cli/internal/state/validation_test.go
Normal file
379
cli/internal/state/validation_test.go
Normal file
|
@ -0,0 +1,379 @@
|
||||||
|
/*
|
||||||
|
Copyright (c) Edgeless Systems GmbH
|
||||||
|
|
||||||
|
SPDX-License-Identifier: AGPL-3.0-only
|
||||||
|
*/
|
||||||
|
|
||||||
|
package state
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/edgelesssys/constellation/v2/internal/cloud/cloudprovider"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestPreCreateValidation(t *testing.T) {
|
||||||
|
testCases := map[string]struct {
|
||||||
|
stateFile func() *State
|
||||||
|
wantErr bool
|
||||||
|
errAssertions func(a *assert.Assertions, err error)
|
||||||
|
}{
|
||||||
|
"valid": {
|
||||||
|
stateFile: func() *State {
|
||||||
|
return &State{
|
||||||
|
Version: Version1,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"invalid version": {
|
||||||
|
stateFile: func() *State {
|
||||||
|
return &State{
|
||||||
|
Version: "invalid",
|
||||||
|
}
|
||||||
|
},
|
||||||
|
wantErr: true,
|
||||||
|
errAssertions: func(a *assert.Assertions, err error) {
|
||||||
|
a.Contains(err.Error(), "validating State.version: invalid must be one of [v1]")
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"infrastructure not empty": {
|
||||||
|
stateFile: func() *State {
|
||||||
|
return &State{
|
||||||
|
Version: Version1,
|
||||||
|
Infrastructure: Infrastructure{
|
||||||
|
ClusterEndpoint: "test",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
},
|
||||||
|
wantErr: true,
|
||||||
|
errAssertions: func(a *assert.Assertions, err error) {
|
||||||
|
a.Contains(err.Error(), "validating State.infrastructure.clusterEndpoint: test must be empty")
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"cluster values not empty": {
|
||||||
|
stateFile: func() *State {
|
||||||
|
return &State{
|
||||||
|
Version: Version1,
|
||||||
|
ClusterValues: ClusterValues{
|
||||||
|
ClusterID: "test",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
},
|
||||||
|
wantErr: true,
|
||||||
|
errAssertions: func(a *assert.Assertions, err error) {
|
||||||
|
a.Contains(err.Error(), "validating State.clusterValues.clusterID: test must be empty")
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for name, tc := range testCases {
|
||||||
|
t.Run(name, func(t *testing.T) {
|
||||||
|
err := tc.stateFile().Validate(PreCreate, cloudprovider.Azure)
|
||||||
|
if tc.wantErr {
|
||||||
|
require.Error(t, err)
|
||||||
|
if tc.errAssertions != nil {
|
||||||
|
tc.errAssertions(assert.New(t), err)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
require.NoError(t, err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPreInitValidation(t *testing.T) {
|
||||||
|
validPreInitState := func() *State {
|
||||||
|
s := defaultState()
|
||||||
|
s.ClusterValues = ClusterValues{}
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
testCases := map[string]struct {
|
||||||
|
stateFile func() *State
|
||||||
|
wantErr bool
|
||||||
|
errAssertions func(a *assert.Assertions, err error)
|
||||||
|
}{
|
||||||
|
"valid": {
|
||||||
|
stateFile: validPreInitState,
|
||||||
|
},
|
||||||
|
"invalid version": {
|
||||||
|
stateFile: func() *State {
|
||||||
|
s := validPreInitState()
|
||||||
|
s.Version = "invalid"
|
||||||
|
return s
|
||||||
|
},
|
||||||
|
wantErr: true,
|
||||||
|
errAssertions: func(a *assert.Assertions, err error) {
|
||||||
|
a.Contains(err.Error(), "validating State.version: invalid must be one of [v1]")
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"cluster endpoint invalid": {
|
||||||
|
stateFile: func() *State {
|
||||||
|
s := validPreInitState()
|
||||||
|
s.Infrastructure.ClusterEndpoint = "invalid"
|
||||||
|
return s
|
||||||
|
},
|
||||||
|
wantErr: true,
|
||||||
|
errAssertions: func(a *assert.Assertions, err error) {
|
||||||
|
a.Contains(err.Error(), "validating State.infrastructure.clusterEndpoint: invalid must be a valid DNS name")
|
||||||
|
a.Contains(err.Error(), "validating State.infrastructure.clusterEndpoint: invalid must be a valid IP address")
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"in-cluster endpoint invalid": {
|
||||||
|
stateFile: func() *State {
|
||||||
|
s := validPreInitState()
|
||||||
|
s.Infrastructure.InClusterEndpoint = "invalid"
|
||||||
|
return s
|
||||||
|
},
|
||||||
|
wantErr: true,
|
||||||
|
errAssertions: func(a *assert.Assertions, err error) {
|
||||||
|
a.Contains(err.Error(), "validating State.infrastructure.inClusterEndpoint: invalid must be a valid DNS name")
|
||||||
|
a.Contains(err.Error(), "validating State.infrastructure.inClusterEndpoint: invalid must be a valid IP address")
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"node ip cidr invalid": {
|
||||||
|
stateFile: func() *State {
|
||||||
|
s := validPreInitState()
|
||||||
|
s.Infrastructure.IPCidrNode = "invalid"
|
||||||
|
return s
|
||||||
|
},
|
||||||
|
wantErr: true,
|
||||||
|
errAssertions: func(a *assert.Assertions, err error) {
|
||||||
|
a.Contains(err.Error(), "validating State.infrastructure.ipCidrNode: invalid must be a valid CIDR")
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"uid empty": {
|
||||||
|
stateFile: func() *State {
|
||||||
|
s := validPreInitState()
|
||||||
|
s.Infrastructure.UID = ""
|
||||||
|
return s
|
||||||
|
},
|
||||||
|
wantErr: true,
|
||||||
|
errAssertions: func(a *assert.Assertions, err error) {
|
||||||
|
a.Contains(err.Error(), "validating State.infrastructure.uid: must not be empty")
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"name empty": {
|
||||||
|
stateFile: func() *State {
|
||||||
|
s := validPreInitState()
|
||||||
|
s.Infrastructure.Name = ""
|
||||||
|
return s
|
||||||
|
},
|
||||||
|
wantErr: true,
|
||||||
|
errAssertions: func(a *assert.Assertions, err error) {
|
||||||
|
a.Contains(err.Error(), "validating State.infrastructure.name: must not be empty")
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"gcp empty": {
|
||||||
|
stateFile: func() *State {
|
||||||
|
s := validPreInitState()
|
||||||
|
s.Infrastructure.GCP = &GCP{}
|
||||||
|
return s
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"gcp nil": {
|
||||||
|
stateFile: func() *State {
|
||||||
|
s := validPreInitState()
|
||||||
|
s.Infrastructure.GCP = nil
|
||||||
|
return s
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"gcp invalid": {
|
||||||
|
stateFile: func() *State {
|
||||||
|
s := validPreInitState()
|
||||||
|
s.Infrastructure.GCP.IPCidrPod = "invalid"
|
||||||
|
return s
|
||||||
|
},
|
||||||
|
wantErr: true,
|
||||||
|
errAssertions: func(a *assert.Assertions, err error) {
|
||||||
|
a.Contains(err.Error(), "validating State.infrastructure.gcp.ipCidrPod: invalid must be a valid CIDR")
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"azure empty": {
|
||||||
|
stateFile: func() *State {
|
||||||
|
s := validPreInitState()
|
||||||
|
s.Infrastructure.Azure = &Azure{}
|
||||||
|
return s
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"azure nil": {
|
||||||
|
stateFile: func() *State {
|
||||||
|
s := validPreInitState()
|
||||||
|
s.Infrastructure.Azure = nil
|
||||||
|
return s
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"azure invalid": {
|
||||||
|
stateFile: func() *State {
|
||||||
|
s := validPreInitState()
|
||||||
|
s.Infrastructure.Azure.NetworkSecurityGroupName = ""
|
||||||
|
return s
|
||||||
|
},
|
||||||
|
wantErr: true,
|
||||||
|
errAssertions: func(a *assert.Assertions, err error) {
|
||||||
|
a.Contains(err.Error(), "validating State.infrastructure.azure.networkSecurityGroupName: must not be empty")
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for name, tc := range testCases {
|
||||||
|
t.Run(name, func(t *testing.T) {
|
||||||
|
err := tc.stateFile().Validate(PreInit, cloudprovider.Azure)
|
||||||
|
if tc.wantErr {
|
||||||
|
require.Error(t, err)
|
||||||
|
if tc.errAssertions != nil {
|
||||||
|
tc.errAssertions(assert.New(t), err)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
require.NoError(t, err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPostInitValidation(t *testing.T) {
|
||||||
|
testCases := map[string]struct {
|
||||||
|
stateFile func() *State
|
||||||
|
provider cloudprovider.Provider
|
||||||
|
wantErr bool
|
||||||
|
errAssertions func(a *assert.Assertions, err error)
|
||||||
|
}{
|
||||||
|
"valid": {
|
||||||
|
stateFile: defaultGCPState,
|
||||||
|
provider: cloudprovider.GCP,
|
||||||
|
},
|
||||||
|
"invalid version": {
|
||||||
|
stateFile: func() *State {
|
||||||
|
s := defaultState()
|
||||||
|
s.Version = "invalid"
|
||||||
|
return s
|
||||||
|
},
|
||||||
|
wantErr: true,
|
||||||
|
errAssertions: func(a *assert.Assertions, err error) {
|
||||||
|
a.Contains(err.Error(), "validating State.version: invalid must be one of [v1]")
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"cluster endpoint invalid": {
|
||||||
|
stateFile: func() *State {
|
||||||
|
s := defaultState()
|
||||||
|
s.Infrastructure.ClusterEndpoint = "invalid"
|
||||||
|
return s
|
||||||
|
},
|
||||||
|
wantErr: true,
|
||||||
|
errAssertions: func(a *assert.Assertions, err error) {
|
||||||
|
a.Contains(err.Error(), "validating State.infrastructure.clusterEndpoint: invalid must be a valid DNS name")
|
||||||
|
a.Contains(err.Error(), "validating State.infrastructure.clusterEndpoint: invalid must be a valid IP address")
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"in-cluster endpoint invalid": {
|
||||||
|
stateFile: func() *State {
|
||||||
|
s := defaultState()
|
||||||
|
s.Infrastructure.InClusterEndpoint = "invalid"
|
||||||
|
return s
|
||||||
|
},
|
||||||
|
wantErr: true,
|
||||||
|
errAssertions: func(a *assert.Assertions, err error) {
|
||||||
|
a.Contains(err.Error(), "validating State.infrastructure.inClusterEndpoint: invalid must be a valid DNS name")
|
||||||
|
a.Contains(err.Error(), "validating State.infrastructure.inClusterEndpoint: invalid must be a valid IP address")
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"node ip cidr invalid": {
|
||||||
|
stateFile: func() *State {
|
||||||
|
s := defaultState()
|
||||||
|
s.Infrastructure.IPCidrNode = "invalid"
|
||||||
|
return s
|
||||||
|
},
|
||||||
|
wantErr: true,
|
||||||
|
errAssertions: func(a *assert.Assertions, err error) {
|
||||||
|
a.Contains(err.Error(), "validating State.infrastructure.ipCidrNode: invalid must be a valid CIDR")
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"uid empty": {
|
||||||
|
stateFile: func() *State {
|
||||||
|
s := defaultState()
|
||||||
|
s.Infrastructure.UID = ""
|
||||||
|
return s
|
||||||
|
},
|
||||||
|
wantErr: true,
|
||||||
|
errAssertions: func(a *assert.Assertions, err error) {
|
||||||
|
a.Contains(err.Error(), "validating State.infrastructure.uid: must not be empty")
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"name empty": {
|
||||||
|
stateFile: func() *State {
|
||||||
|
s := defaultState()
|
||||||
|
s.Infrastructure.Name = ""
|
||||||
|
return s
|
||||||
|
},
|
||||||
|
wantErr: true,
|
||||||
|
errAssertions: func(a *assert.Assertions, err error) {
|
||||||
|
a.Contains(err.Error(), "validating State.infrastructure.name: must not be empty")
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"gcp valid": {
|
||||||
|
stateFile: func() *State {
|
||||||
|
s := defaultGCPState()
|
||||||
|
return s
|
||||||
|
},
|
||||||
|
provider: cloudprovider.GCP,
|
||||||
|
},
|
||||||
|
"azure valid": {
|
||||||
|
stateFile: func() *State {
|
||||||
|
s := defaultAzureState()
|
||||||
|
return s
|
||||||
|
},
|
||||||
|
provider: cloudprovider.Azure,
|
||||||
|
},
|
||||||
|
"gcp, azure not nil": {
|
||||||
|
stateFile: func() *State {
|
||||||
|
s := defaultState()
|
||||||
|
return s
|
||||||
|
},
|
||||||
|
provider: cloudprovider.GCP,
|
||||||
|
wantErr: true,
|
||||||
|
errAssertions: func(a *assert.Assertions, err error) {
|
||||||
|
a.Contains(err.Error(), "must be equal to <nil>")
|
||||||
|
a.Contains(err.Error(), "must be empty")
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"azure, gcp not nil": {
|
||||||
|
stateFile: func() *State {
|
||||||
|
s := defaultState()
|
||||||
|
return s
|
||||||
|
},
|
||||||
|
provider: cloudprovider.Azure,
|
||||||
|
wantErr: true,
|
||||||
|
errAssertions: func(a *assert.Assertions, err error) {
|
||||||
|
a.Contains(err.Error(), "must be equal to <nil>")
|
||||||
|
a.Contains(err.Error(), "must be empty")
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"cluster values invalid": {
|
||||||
|
stateFile: func() *State {
|
||||||
|
s := defaultState()
|
||||||
|
s.ClusterValues.ClusterID = ""
|
||||||
|
return s
|
||||||
|
},
|
||||||
|
wantErr: true,
|
||||||
|
errAssertions: func(a *assert.Assertions, err error) {
|
||||||
|
a.Contains(err.Error(), "validating State.clusterValues.clusterID: must not be empty")
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for name, tc := range testCases {
|
||||||
|
t.Run(name, func(t *testing.T) {
|
||||||
|
err := tc.stateFile().Validate(PostInit, tc.provider)
|
||||||
|
if tc.wantErr {
|
||||||
|
require.Error(t, err)
|
||||||
|
if tc.errAssertions != nil {
|
||||||
|
tc.errAssertions(assert.New(t), err)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
require.NoError(t, err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
|
@ -15,6 +15,7 @@ go_library(
|
||||||
go_test(
|
go_test(
|
||||||
name = "validation_test",
|
name = "validation_test",
|
||||||
srcs = [
|
srcs = [
|
||||||
|
"constraints_test.go",
|
||||||
"errors_test.go",
|
"errors_test.go",
|
||||||
"validation_test.go",
|
"validation_test.go",
|
||||||
],
|
],
|
||||||
|
|
|
@ -8,6 +8,7 @@ package validation
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"net"
|
||||||
"reflect"
|
"reflect"
|
||||||
"regexp"
|
"regexp"
|
||||||
)
|
)
|
||||||
|
@ -15,8 +16,10 @@ import (
|
||||||
// Constraint is a constraint on a document or a field of a document.
|
// Constraint is a constraint on a document or a field of a document.
|
||||||
type Constraint struct {
|
type Constraint struct {
|
||||||
// Satisfied returns no error if the constraint is satisfied.
|
// Satisfied returns no error if the constraint is satisfied.
|
||||||
// Otherwise, it returns the reason why the constraint is not satisfied.
|
// Otherwise, it returns the reason why the constraint is not satisfied,
|
||||||
Satisfied func() error
|
// possibly including its child errors, i.e., errors returned by constraints
|
||||||
|
// that are embedded in this constraint.
|
||||||
|
Satisfied func() *TreeError
|
||||||
}
|
}
|
||||||
|
|
||||||
/*
|
/*
|
||||||
|
@ -36,7 +39,7 @@ Example for a pointer field:
|
||||||
Due to Go's addressability limititations regarding maps, if a map field is
|
Due to Go's addressability limititations regarding maps, if a map field is
|
||||||
to be validated, WithMapFieldTrace must be used instead of WithFieldTrace.
|
to be validated, WithMapFieldTrace must be used instead of WithFieldTrace.
|
||||||
*/
|
*/
|
||||||
func (c *Constraint) WithFieldTrace(doc any, field any) Constraint {
|
func (c *Constraint) WithFieldTrace(doc any, field any) *Constraint {
|
||||||
// we only want to dereference the needle once to dereference the pointer
|
// we only want to dereference the needle once to dereference the pointer
|
||||||
// used to pass it to the function without losing reference to it, as the
|
// used to pass it to the function without losing reference to it, as the
|
||||||
// needle could be an arbitrarily long chain of pointers. The same
|
// needle could be an arbitrarily long chain of pointers. The same
|
||||||
|
@ -69,7 +72,7 @@ Example:
|
||||||
|
|
||||||
For non-map fields, WithFieldTrace should be used instead of WithMapFieldTrace.
|
For non-map fields, WithFieldTrace should be used instead of WithMapFieldTrace.
|
||||||
*/
|
*/
|
||||||
func (c *Constraint) WithMapFieldTrace(doc any, field any, mapKey string) Constraint {
|
func (c *Constraint) WithMapFieldTrace(doc any, field any, mapKey string) *Constraint {
|
||||||
// we only want to dereference the needle once to dereference the pointer
|
// we only want to dereference the needle once to dereference the pointer
|
||||||
// used to pass it to the function without losing reference to it, as the
|
// used to pass it to the function without losing reference to it, as the
|
||||||
// needle could be an arbitrarily long chain of pointers. The same
|
// needle could be an arbitrarily long chain of pointers. The same
|
||||||
|
@ -91,11 +94,11 @@ func (c *Constraint) WithMapFieldTrace(doc any, field any, mapKey string) Constr
|
||||||
}
|
}
|
||||||
|
|
||||||
// withTrace wraps the constraint's error message with a well-formatted trace.
|
// withTrace wraps the constraint's error message with a well-formatted trace.
|
||||||
func (c *Constraint) withTrace(docRef, fieldRef referenceableValue) Constraint {
|
func (c *Constraint) withTrace(docRef, fieldRef referenceableValue) *Constraint {
|
||||||
return Constraint{
|
return &Constraint{
|
||||||
Satisfied: func() error {
|
Satisfied: func() *TreeError {
|
||||||
if err := c.Satisfied(); err != nil {
|
if err := c.Satisfied(); err != nil {
|
||||||
return newError(docRef, fieldRef, err)
|
return newTraceError(docRef, fieldRef, err)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
|
@ -105,49 +108,207 @@ func (c *Constraint) withTrace(docRef, fieldRef referenceableValue) Constraint {
|
||||||
// MatchRegex is a constraint that if s matches regex.
|
// MatchRegex is a constraint that if s matches regex.
|
||||||
func MatchRegex(s string, regex string) *Constraint {
|
func MatchRegex(s string, regex string) *Constraint {
|
||||||
return &Constraint{
|
return &Constraint{
|
||||||
Satisfied: func() error {
|
Satisfied: func() *TreeError {
|
||||||
if !regexp.MustCompile(regex).MatchString(s) {
|
if !regexp.MustCompile(regex).MatchString(s) {
|
||||||
return fmt.Errorf("%s must match the pattern %s", s, regex)
|
return NewErrorTree(fmt.Errorf("%s must match the pattern %s", s, regex))
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Equal is a constraint that if s is equal to t.
|
// Equal is a constraint that checks if s is equal to t.
|
||||||
func Equal[T comparable](s T, t T) *Constraint {
|
func Equal[T comparable](s T, t T) *Constraint {
|
||||||
return &Constraint{
|
return &Constraint{
|
||||||
Satisfied: func() error {
|
Satisfied: func() *TreeError {
|
||||||
if s != t {
|
if s != t {
|
||||||
return fmt.Errorf("%v must be equal to %v", s, t)
|
return NewErrorTree(fmt.Errorf("%v must be equal to %v", s, t))
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// NotEmpty is a constraint that if s is not empty.
|
// NotEqual is a constraint that checks if s is not equal to t.
|
||||||
func NotEmpty[T comparable](s T) *Constraint {
|
func NotEqual[T comparable](s T, t T) *Constraint {
|
||||||
return &Constraint{
|
return &Constraint{
|
||||||
Satisfied: func() error {
|
Satisfied: func() *TreeError {
|
||||||
var zero T
|
if Equal(s, t).Satisfied() == nil {
|
||||||
if s == zero {
|
return NewErrorTree(fmt.Errorf("%v must not be equal to %v", s, t))
|
||||||
return fmt.Errorf("%v must not be empty", s)
|
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Empty is a constraint that if s is empty.
|
// Empty is a constraint that checks if s is empty.
|
||||||
func Empty[T comparable](s T) *Constraint {
|
func Empty[T comparable](s T) *Constraint {
|
||||||
return &Constraint{
|
return &Constraint{
|
||||||
Satisfied: func() error {
|
Satisfied: func() *TreeError {
|
||||||
var zero T
|
var zero T
|
||||||
if s != zero {
|
if s != zero {
|
||||||
return fmt.Errorf("%v must be empty", s)
|
return NewErrorTree(fmt.Errorf("%v must be empty", s))
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// NotEmpty is a constraint that checks if s is not empty.
|
||||||
|
func NotEmpty[T comparable](s T) *Constraint {
|
||||||
|
return &Constraint{
|
||||||
|
Satisfied: func() *TreeError {
|
||||||
|
if Empty(s).Satisfied() == nil {
|
||||||
|
return NewErrorTree(fmt.Errorf("must not be empty"))
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// OneOf is a constraint that s is in the set of values p.
|
||||||
|
func OneOf[T comparable](s T, p []T) *Constraint {
|
||||||
|
return &Constraint{
|
||||||
|
Satisfied: func() *TreeError {
|
||||||
|
for _, v := range p {
|
||||||
|
if s == v {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return NewErrorTree(fmt.Errorf("%v must be one of %v", s, p))
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// IPAddress is a constraint that checks if s is a valid IP address.
|
||||||
|
func IPAddress(s string) *Constraint {
|
||||||
|
return &Constraint{
|
||||||
|
Satisfied: func() *TreeError {
|
||||||
|
if net.ParseIP(s) == nil {
|
||||||
|
return NewErrorTree(fmt.Errorf("%s must be a valid IP address", s))
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// CIDR is a constraint that checks if s is a valid CIDR.
|
||||||
|
func CIDR(s string) *Constraint {
|
||||||
|
return &Constraint{
|
||||||
|
Satisfied: func() *TreeError {
|
||||||
|
if _, _, err := net.ParseCIDR(s); err != nil {
|
||||||
|
return NewErrorTree(fmt.Errorf("%s must be a valid CIDR", s))
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// DNSName is a constraint that checks if s is a valid DNS name.
|
||||||
|
func DNSName(s string) *Constraint {
|
||||||
|
return &Constraint{
|
||||||
|
Satisfied: func() *TreeError {
|
||||||
|
if _, err := net.LookupHost(s); err != nil {
|
||||||
|
return NewErrorTree(fmt.Errorf("%s must be a valid DNS name", s))
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// EmptySlice is a constraint that checks if s is an empty slice.
|
||||||
|
func EmptySlice[T comparable](s []T) *Constraint {
|
||||||
|
return &Constraint{
|
||||||
|
Satisfied: func() *TreeError {
|
||||||
|
if len(s) != 0 {
|
||||||
|
return NewErrorTree(fmt.Errorf("%v must be empty", s))
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// NotEmptySlice is a constraint that checks if slice s is not empty.
|
||||||
|
func NotEmptySlice[T comparable](s []T) *Constraint {
|
||||||
|
return &Constraint{
|
||||||
|
Satisfied: func() *TreeError {
|
||||||
|
if EmptySlice(s).Satisfied() == nil {
|
||||||
|
return NewErrorTree(fmt.Errorf("must not be empty"))
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// All is a constraint that checks if all elements of s satisfy the constraint c.
|
||||||
|
// The constraint should be parametric in regards to the index of the element in s,
|
||||||
|
// as well as the element itself.
|
||||||
|
func All[T comparable](s []T, c func(i int, v T) *Constraint) *Constraint {
|
||||||
|
return &Constraint{
|
||||||
|
Satisfied: func() *TreeError {
|
||||||
|
retErr := NewErrorTree(fmt.Errorf("all of the constraints must be satisfied: "))
|
||||||
|
for i, v := range s {
|
||||||
|
if err := c(i, v).Satisfied(); err != nil {
|
||||||
|
retErr.appendChild(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(retErr.children) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return retErr
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// And groups multiple constraints in an "and" relation and fails according to the given strategy.
|
||||||
|
func And(errStrat ErrStrategy, constraints ...*Constraint) *Constraint {
|
||||||
|
return &Constraint{
|
||||||
|
Satisfied: func() *TreeError {
|
||||||
|
retErr := NewErrorTree(fmt.Errorf("all of the constraints must be satisfied: "))
|
||||||
|
for _, constraint := range constraints {
|
||||||
|
if err := constraint.Satisfied(); err != nil {
|
||||||
|
if errStrat == FailFast {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
retErr.appendChild(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(retErr.children) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return retErr
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Or groups multiple constraints in an "or" relation.
|
||||||
|
func Or(constraints ...*Constraint) *Constraint {
|
||||||
|
return &Constraint{
|
||||||
|
Satisfied: func() *TreeError {
|
||||||
|
retErr := NewErrorTree(fmt.Errorf("at least one of the constraints must be satisfied: "))
|
||||||
|
for _, constraint := range constraints {
|
||||||
|
err := constraint.Satisfied()
|
||||||
|
if err == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
retErr.appendChild(err)
|
||||||
|
}
|
||||||
|
if len(retErr.children) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return retErr
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// IfNotNil evaluates a constraint if and only if s is not nil.
|
||||||
|
func IfNotNil[T comparable](s *T, c func() *Constraint) *Constraint {
|
||||||
|
return &Constraint{
|
||||||
|
Satisfied: func() *TreeError {
|
||||||
|
if s == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return c().Satisfied()
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
290
internal/validation/constraints_test.go
Normal file
290
internal/validation/constraints_test.go
Normal file
|
@ -0,0 +1,290 @@
|
||||||
|
/*
|
||||||
|
Copyright (c) Edgeless Systems GmbH
|
||||||
|
|
||||||
|
SPDX-License-Identifier: AGPL-3.0-only
|
||||||
|
*/
|
||||||
|
|
||||||
|
package validation
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestIPAddress(t *testing.T) {
|
||||||
|
testCases := map[string]struct {
|
||||||
|
ip string
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
"valid ipv4": {
|
||||||
|
ip: "127.0.0.1",
|
||||||
|
},
|
||||||
|
"valid ipv6": {
|
||||||
|
ip: "2001:db8::68",
|
||||||
|
},
|
||||||
|
"invalid": {
|
||||||
|
ip: "invalid",
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
"empty": {
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for name, tc := range testCases {
|
||||||
|
t.Run(name, func(t *testing.T) {
|
||||||
|
err := IPAddress(tc.ip).Satisfied()
|
||||||
|
if tc.wantErr {
|
||||||
|
require.Error(t, err)
|
||||||
|
} else {
|
||||||
|
require.Nil(t, err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCIDR(t *testing.T) {
|
||||||
|
testCases := map[string]struct {
|
||||||
|
cidr string
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
"valid ipv4": {
|
||||||
|
cidr: "192.0.2.0/24",
|
||||||
|
},
|
||||||
|
"valid ipv6": {
|
||||||
|
cidr: "2001:db8::/32",
|
||||||
|
},
|
||||||
|
"invalid": {
|
||||||
|
cidr: "invalid",
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
"empty": {
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for name, tc := range testCases {
|
||||||
|
t.Run(name, func(t *testing.T) {
|
||||||
|
err := CIDR(tc.cidr).Satisfied()
|
||||||
|
if tc.wantErr {
|
||||||
|
require.Error(t, err)
|
||||||
|
} else {
|
||||||
|
require.Nil(t, err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDNSName(t *testing.T) {
|
||||||
|
testCases := map[string]struct {
|
||||||
|
dnsName string
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
"valid": {
|
||||||
|
dnsName: "example.com",
|
||||||
|
},
|
||||||
|
"invalid": {
|
||||||
|
dnsName: "invalid",
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
"empty": {
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for name, tc := range testCases {
|
||||||
|
t.Run(name, func(t *testing.T) {
|
||||||
|
err := DNSName(tc.dnsName).Satisfied()
|
||||||
|
if tc.wantErr {
|
||||||
|
require.Error(t, err)
|
||||||
|
} else {
|
||||||
|
require.Nil(t, err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEmptySlice(t *testing.T) {
|
||||||
|
testCases := map[string]struct {
|
||||||
|
s []any
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
"valid": {
|
||||||
|
s: []any{},
|
||||||
|
},
|
||||||
|
"nil": {
|
||||||
|
s: nil,
|
||||||
|
},
|
||||||
|
"invalid": {
|
||||||
|
s: []any{1},
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for name, tc := range testCases {
|
||||||
|
t.Run(name, func(t *testing.T) {
|
||||||
|
err := EmptySlice(tc.s).Satisfied()
|
||||||
|
if tc.wantErr {
|
||||||
|
require.Error(t, err)
|
||||||
|
} else {
|
||||||
|
require.Nil(t, err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNotEmptySlice(t *testing.T) {
|
||||||
|
testCases := map[string]struct {
|
||||||
|
s []any
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
"valid": {
|
||||||
|
s: []any{1},
|
||||||
|
},
|
||||||
|
"invalid": {
|
||||||
|
s: []any{},
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
"nil": {
|
||||||
|
s: nil,
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for name, tc := range testCases {
|
||||||
|
t.Run(name, func(t *testing.T) {
|
||||||
|
err := NotEmptySlice(tc.s).Satisfied()
|
||||||
|
if tc.wantErr {
|
||||||
|
require.Error(t, err)
|
||||||
|
} else {
|
||||||
|
require.Nil(t, err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAll(t *testing.T) {
|
||||||
|
c := func(i int, s string) *Constraint {
|
||||||
|
return Equal(s, "abc")
|
||||||
|
}
|
||||||
|
testCases := map[string]struct {
|
||||||
|
s []string
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
"valid": {
|
||||||
|
s: []string{"abc", "abc", "abc"},
|
||||||
|
},
|
||||||
|
"nil": {
|
||||||
|
s: nil,
|
||||||
|
},
|
||||||
|
"empty": {
|
||||||
|
s: []string{},
|
||||||
|
},
|
||||||
|
"all are invalid": {
|
||||||
|
s: []string{"def", "lol"},
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
"one is invalid": {
|
||||||
|
s: []string{"abc", "def"},
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for name, tc := range testCases {
|
||||||
|
t.Run(name, func(t *testing.T) {
|
||||||
|
err := All(tc.s, c).Satisfied()
|
||||||
|
if tc.wantErr {
|
||||||
|
require.Error(t, err)
|
||||||
|
} else {
|
||||||
|
require.Nil(t, err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNotEqual(t *testing.T) {
|
||||||
|
testCases := map[string]struct {
|
||||||
|
a any
|
||||||
|
b any
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
"valid": {
|
||||||
|
a: "abc",
|
||||||
|
b: "def",
|
||||||
|
},
|
||||||
|
"invalid": {
|
||||||
|
a: "abc",
|
||||||
|
b: "abc",
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
"empty": {
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
"one empty": {
|
||||||
|
a: "abc",
|
||||||
|
b: "",
|
||||||
|
},
|
||||||
|
"one nil": {
|
||||||
|
a: "abc",
|
||||||
|
b: nil,
|
||||||
|
},
|
||||||
|
"both nil": {
|
||||||
|
a: nil,
|
||||||
|
b: nil,
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for name, tc := range testCases {
|
||||||
|
t.Run(name, func(t *testing.T) {
|
||||||
|
err := NotEqual(tc.a, tc.b).Satisfied()
|
||||||
|
if tc.wantErr {
|
||||||
|
require.Error(t, err)
|
||||||
|
} else {
|
||||||
|
require.Nil(t, err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIfNotNil(t *testing.T) {
|
||||||
|
testCases := map[string]struct {
|
||||||
|
a *int
|
||||||
|
c func() *Constraint
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
"valid": {
|
||||||
|
a: new(int),
|
||||||
|
c: func() *Constraint {
|
||||||
|
return &Constraint{
|
||||||
|
Satisfied: func() *TreeError {
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"nil": {
|
||||||
|
a: nil,
|
||||||
|
c: func() *Constraint {
|
||||||
|
return &Constraint{
|
||||||
|
Satisfied: func() *TreeError {
|
||||||
|
t.Fatal("should not be called")
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for name, tc := range testCases {
|
||||||
|
t.Run(name, func(t *testing.T) {
|
||||||
|
err := IfNotNil(tc.a, tc.c).Satisfied()
|
||||||
|
if tc.wantErr {
|
||||||
|
require.Error(t, err)
|
||||||
|
} else {
|
||||||
|
require.Nil(t, err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
|
@ -13,42 +13,86 @@ import (
|
||||||
"strings"
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Error is returned when a document is not valid.
|
// TreeError is returned when a document is not valid.
|
||||||
type Error struct {
|
// It contains the path to the field that failed validation, the error
|
||||||
Path string
|
// that occurred, as well as a list of child errors, as one constraint
|
||||||
Err error
|
// can embed multiple other constraints, e.g. in an OR.
|
||||||
|
type TreeError struct {
|
||||||
|
path string
|
||||||
|
err error
|
||||||
|
children []*TreeError
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewErrorTree creates a new error tree from the given error.
|
||||||
|
func NewErrorTree(err error) *TreeError {
|
||||||
|
return &TreeError{
|
||||||
|
err: err,
|
||||||
|
children: []*TreeError{},
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/*
|
/*
|
||||||
newError creates a new validation Error.
|
newTraceError creates a new validation error, traced to a field.
|
||||||
|
|
||||||
To find the path to the exported field that failed validation, it traverses "doc"
|
To find the path to the exported field that failed validation, it traverses "doc"
|
||||||
recursively until it finds a field in "doc" that matches the reference to "field".
|
recursively until it finds a field in "doc" that matches the reference to "field".
|
||||||
*/
|
*/
|
||||||
func newError(doc, field referenceableValue, errMsg error) *Error {
|
func newTraceError(doc, field referenceableValue, errMsg error) *TreeError {
|
||||||
// traverse the top level struct (i.e. the "haystack") until addr (i.e. the "needle") is found
|
// traverse the top level struct (i.e. the "haystack") until addr (i.e. the "needle") is found
|
||||||
path, err := traverse(doc, field, newPathBuilder(doc._type.Name()))
|
path, err := traverse(doc, field, newPathBuilder(doc._type.Name()))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return &Error{
|
return &TreeError{
|
||||||
Path: "unknown",
|
path: "unknown",
|
||||||
Err: fmt.Errorf("cannot find path to field: %w. original error: %w", err, errMsg),
|
err: fmt.Errorf("cannot find path to field: %w. original error: %w", err, errMsg),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return &Error{
|
return &TreeError{
|
||||||
Path: path,
|
path: path,
|
||||||
Err: errMsg,
|
err: errMsg,
|
||||||
|
children: []*TreeError{},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Error implements the error interface.
|
// Error implements the error interface.
|
||||||
func (e *Error) Error() string {
|
func (e *TreeError) Error() string {
|
||||||
return fmt.Sprintf("validating %s: %s", e.Path, e.Err)
|
return e.format(0)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Unwrap implements the error interface.
|
// Unwrap implements the error interface.
|
||||||
func (e *Error) Unwrap() error {
|
func (e *TreeError) Unwrap() error {
|
||||||
return e.Err
|
return e.err
|
||||||
|
}
|
||||||
|
|
||||||
|
// format formats the error tree and all of its children.
|
||||||
|
func (e *TreeError) format(indent int) string {
|
||||||
|
var sb strings.Builder
|
||||||
|
if e.path != "" {
|
||||||
|
sb.WriteString(fmt.Sprintf(
|
||||||
|
"%svalidating %s: %s",
|
||||||
|
strings.Repeat(" ", indent),
|
||||||
|
e.path,
|
||||||
|
e.err,
|
||||||
|
))
|
||||||
|
} else {
|
||||||
|
sb.WriteString(fmt.Sprintf(
|
||||||
|
"%s%s",
|
||||||
|
strings.Repeat(" ", indent),
|
||||||
|
e.err,
|
||||||
|
))
|
||||||
|
}
|
||||||
|
for _, child := range e.children {
|
||||||
|
sb.WriteString(fmt.Sprintf(
|
||||||
|
"\n%s",
|
||||||
|
child.format(indent+1),
|
||||||
|
))
|
||||||
|
}
|
||||||
|
return sb.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
// appendChild adds the given child error to the tree.
|
||||||
|
func (e *TreeError) appendChild(child *TreeError) {
|
||||||
|
e.children = append(e.children, child)
|
||||||
}
|
}
|
||||||
|
|
||||||
/*
|
/*
|
||||||
|
@ -238,9 +282,13 @@ func newPathBuilder(topLevelDoc string) pathBuilder {
|
||||||
func (p pathBuilder) appendStructField(field reflect.StructField) pathBuilder {
|
func (p pathBuilder) appendStructField(field reflect.StructField) pathBuilder {
|
||||||
switch {
|
switch {
|
||||||
case field.Tag.Get("json") != "":
|
case field.Tag.Get("json") != "":
|
||||||
p.buf = append(p.buf, fmt.Sprintf(".%s", field.Tag.Get("json")))
|
// cut off omitempty or other options
|
||||||
|
jsonTagName, _, _ := strings.Cut(field.Tag.Get("json"), ",")
|
||||||
|
p.buf = append(p.buf, fmt.Sprintf(".%s", jsonTagName))
|
||||||
case field.Tag.Get("yaml") != "":
|
case field.Tag.Get("yaml") != "":
|
||||||
p.buf = append(p.buf, fmt.Sprintf(".%s", field.Tag.Get("yaml")))
|
// cut off omitempty or other options
|
||||||
|
yamlTagName, _, _ := strings.Cut(field.Tag.Get("yaml"), ",")
|
||||||
|
p.buf = append(p.buf, fmt.Sprintf(".%s", yamlTagName))
|
||||||
default:
|
default:
|
||||||
p.buf = append(p.buf, fmt.Sprintf(".%s", field.Name))
|
p.buf = append(p.buf, fmt.Sprintf(".%s", field.Name))
|
||||||
}
|
}
|
||||||
|
|
|
@ -15,6 +15,37 @@ import (
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func TestErrorFormatting(t *testing.T) {
|
||||||
|
err := &TreeError{
|
||||||
|
path: "path",
|
||||||
|
err: fmt.Errorf("error"),
|
||||||
|
children: []*TreeError{},
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.Equal(t, "validating path: error", err.Error())
|
||||||
|
|
||||||
|
err.children = append(err.children, &TreeError{
|
||||||
|
path: "child",
|
||||||
|
err: fmt.Errorf("child error"),
|
||||||
|
children: []*TreeError{},
|
||||||
|
})
|
||||||
|
|
||||||
|
assert.Equal(t, "validating path: error\n validating child: child error", err.Error())
|
||||||
|
|
||||||
|
err.children = append(err.children, &TreeError{
|
||||||
|
path: "child2",
|
||||||
|
err: fmt.Errorf("child2 error"),
|
||||||
|
children: []*TreeError{
|
||||||
|
{
|
||||||
|
path: "child2child",
|
||||||
|
err: fmt.Errorf("child2child error"),
|
||||||
|
children: []*TreeError{},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
assert.Equal(t, "validating path: error\n validating child: child error\n validating child2: child2 error\n validating child2child: child2child error", err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
// Tests for primitive / shallow fields
|
// Tests for primitive / shallow fields
|
||||||
|
|
||||||
func TestNewValidationErrorSingleField(t *testing.T) {
|
func TestNewValidationErrorSingleField(t *testing.T) {
|
||||||
|
@ -24,7 +55,7 @@ func TestNewValidationErrorSingleField(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
doc, field := references(t, st, &st.OtherField, "")
|
doc, field := references(t, st, &st.OtherField, "")
|
||||||
err := newError(doc, field, assert.AnError)
|
err := newTraceError(doc, field, assert.AnError)
|
||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
require.Contains(t, err.Error(), fmt.Sprintf("validating errorTestDoc.otherField: %s", assert.AnError))
|
require.Contains(t, err.Error(), fmt.Sprintf("validating errorTestDoc.otherField: %s", assert.AnError))
|
||||||
}
|
}
|
||||||
|
@ -37,7 +68,7 @@ func TestNewValidationErrorSingleFieldPtr(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
doc, field := references(t, st, &st.PointerField, "")
|
doc, field := references(t, st, &st.PointerField, "")
|
||||||
err := newError(doc, field, assert.AnError)
|
err := newTraceError(doc, field, assert.AnError)
|
||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
require.Contains(t, err.Error(), fmt.Sprintf("validating errorTestDoc.pointerField: %s", assert.AnError))
|
require.Contains(t, err.Error(), fmt.Sprintf("validating errorTestDoc.pointerField: %s", assert.AnError))
|
||||||
}
|
}
|
||||||
|
@ -51,7 +82,7 @@ func TestNewValidationErrorSingleFieldDoublePtr(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
doc, field := references(t, st, &st.DoublePointerField, "")
|
doc, field := references(t, st, &st.DoublePointerField, "")
|
||||||
err := newError(doc, field, assert.AnError)
|
err := newTraceError(doc, field, assert.AnError)
|
||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
require.Contains(t, err.Error(), fmt.Sprintf("validating errorTestDoc.doublePointerField: %s", assert.AnError))
|
require.Contains(t, err.Error(), fmt.Sprintf("validating errorTestDoc.doublePointerField: %s", assert.AnError))
|
||||||
}
|
}
|
||||||
|
@ -66,7 +97,7 @@ func TestNewValidationErrorSingleFieldInexistent(t *testing.T) {
|
||||||
inexistentField := 123
|
inexistentField := 123
|
||||||
|
|
||||||
doc, field := references(t, st, &inexistentField, "")
|
doc, field := references(t, st, &inexistentField, "")
|
||||||
err := newError(doc, field, assert.AnError)
|
err := newTraceError(doc, field, assert.AnError)
|
||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
require.Contains(t, err.Error(), "cannot find path to field: cannot traverse anymore")
|
require.Contains(t, err.Error(), "cannot find path to field: cannot traverse anymore")
|
||||||
}
|
}
|
||||||
|
@ -84,7 +115,7 @@ func TestNewValidationErrorNestedField(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
doc, field := references(t, st, &st.NestedField.OtherField, "")
|
doc, field := references(t, st, &st.NestedField.OtherField, "")
|
||||||
err := newError(doc, field, assert.AnError)
|
err := newTraceError(doc, field, assert.AnError)
|
||||||
t.Log(err)
|
t.Log(err)
|
||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
require.Contains(t, err.Error(), fmt.Sprintf("validating errorTestDoc.nestedField.otherField: %s", assert.AnError))
|
require.Contains(t, err.Error(), fmt.Sprintf("validating errorTestDoc.nestedField.otherField: %s", assert.AnError))
|
||||||
|
@ -102,7 +133,7 @@ func TestNewValidationErrorPointerInNestedField(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
doc, field := references(t, st, &st.NestedField.PointerField, "")
|
doc, field := references(t, st, &st.NestedField.PointerField, "")
|
||||||
err := newError(doc, field, assert.AnError)
|
err := newTraceError(doc, field, assert.AnError)
|
||||||
t.Log(err)
|
t.Log(err)
|
||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
require.Contains(t, err.Error(), fmt.Sprintf("validating errorTestDoc.nestedField.pointerField: %s", assert.AnError))
|
require.Contains(t, err.Error(), fmt.Sprintf("validating errorTestDoc.nestedField.pointerField: %s", assert.AnError))
|
||||||
|
@ -123,7 +154,7 @@ func TestNewValidationErrorNestedFieldPtr(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
doc, field := references(t, st, &st.NestedPointerField.OtherField, "")
|
doc, field := references(t, st, &st.NestedPointerField.OtherField, "")
|
||||||
err := newError(doc, field, assert.AnError)
|
err := newTraceError(doc, field, assert.AnError)
|
||||||
t.Log(err)
|
t.Log(err)
|
||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
require.Contains(t, err.Error(), fmt.Sprintf("validating errorTestDoc.nestedPointerField.otherField: %s", assert.AnError))
|
require.Contains(t, err.Error(), fmt.Sprintf("validating errorTestDoc.nestedPointerField.otherField: %s", assert.AnError))
|
||||||
|
@ -144,7 +175,7 @@ func TestNewValidationErrorNestedNestedField(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
doc, field := references(t, st, &st.NestedField.NestedField.OtherField, "")
|
doc, field := references(t, st, &st.NestedField.NestedField.OtherField, "")
|
||||||
err := newError(doc, field, assert.AnError)
|
err := newTraceError(doc, field, assert.AnError)
|
||||||
t.Log(err)
|
t.Log(err)
|
||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
require.Contains(t, err.Error(), fmt.Sprintf("validating errorTestDoc.nestedField.nestedField.otherField: %s", assert.AnError))
|
require.Contains(t, err.Error(), fmt.Sprintf("validating errorTestDoc.nestedField.nestedField.otherField: %s", assert.AnError))
|
||||||
|
@ -165,7 +196,7 @@ func TestNewValidationErrorNestedNestedFieldPtr(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
doc, field := references(t, st, &st.NestedField.NestedPointerField.OtherField, "")
|
doc, field := references(t, st, &st.NestedField.NestedPointerField.OtherField, "")
|
||||||
err := newError(doc, field, assert.AnError)
|
err := newTraceError(doc, field, assert.AnError)
|
||||||
t.Log(err)
|
t.Log(err)
|
||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
require.Contains(t, err.Error(), fmt.Sprintf("validating errorTestDoc.nestedField.nestedPointerField.otherField: %s", assert.AnError))
|
require.Contains(t, err.Error(), fmt.Sprintf("validating errorTestDoc.nestedField.nestedPointerField.otherField: %s", assert.AnError))
|
||||||
|
@ -186,7 +217,7 @@ func TestNewValidationErrorNestedPtrNestedFieldPtr(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
doc, field := references(t, st, &st.NestedPointerField.NestedPointerField.OtherField, "")
|
doc, field := references(t, st, &st.NestedPointerField.NestedPointerField.OtherField, "")
|
||||||
err := newError(doc, field, assert.AnError)
|
err := newTraceError(doc, field, assert.AnError)
|
||||||
t.Log(err)
|
t.Log(err)
|
||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
require.Contains(t, err.Error(), fmt.Sprintf("validating errorTestDoc.nestedPointerField.nestedPointerField.otherField: %s", assert.AnError))
|
require.Contains(t, err.Error(), fmt.Sprintf("validating errorTestDoc.nestedPointerField.nestedPointerField.otherField: %s", assert.AnError))
|
||||||
|
@ -200,7 +231,7 @@ func TestNewValidationErrorPrimitiveSlice(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
doc, field := references(t, st, &st.PrimitiveSlice[1], "")
|
doc, field := references(t, st, &st.PrimitiveSlice[1], "")
|
||||||
err := newError(doc, field, assert.AnError)
|
err := newTraceError(doc, field, assert.AnError)
|
||||||
t.Log(err)
|
t.Log(err)
|
||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
require.Contains(t, err.Error(), fmt.Sprintf("validating sliceErrorTestDoc.primitiveSlice[1]: %s", assert.AnError))
|
require.Contains(t, err.Error(), fmt.Sprintf("validating sliceErrorTestDoc.primitiveSlice[1]: %s", assert.AnError))
|
||||||
|
@ -212,7 +243,7 @@ func TestNewValidationErrorPrimitiveArray(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
doc, field := references(t, st, &st.PrimitiveArray[1], "")
|
doc, field := references(t, st, &st.PrimitiveArray[1], "")
|
||||||
err := newError(doc, field, assert.AnError)
|
err := newTraceError(doc, field, assert.AnError)
|
||||||
t.Log(err)
|
t.Log(err)
|
||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
require.Contains(t, err.Error(), fmt.Sprintf("validating sliceErrorTestDoc.primitiveArray[1]: %s", assert.AnError))
|
require.Contains(t, err.Error(), fmt.Sprintf("validating sliceErrorTestDoc.primitiveArray[1]: %s", assert.AnError))
|
||||||
|
@ -233,7 +264,7 @@ func TestNewValidationErrorStructSlice(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
doc, field := references(t, st, &st.StructSlice[1].OtherField, "")
|
doc, field := references(t, st, &st.StructSlice[1].OtherField, "")
|
||||||
err := newError(doc, field, assert.AnError)
|
err := newTraceError(doc, field, assert.AnError)
|
||||||
t.Log(err)
|
t.Log(err)
|
||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
require.Contains(t, err.Error(), fmt.Sprintf("validating sliceErrorTestDoc.structSlice[1].otherField: %s", assert.AnError))
|
require.Contains(t, err.Error(), fmt.Sprintf("validating sliceErrorTestDoc.structSlice[1].otherField: %s", assert.AnError))
|
||||||
|
@ -254,7 +285,7 @@ func TestNewValidationErrorStructArray(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
doc, field := references(t, st, &st.StructArray[1].OtherField, "")
|
doc, field := references(t, st, &st.StructArray[1].OtherField, "")
|
||||||
err := newError(doc, field, assert.AnError)
|
err := newTraceError(doc, field, assert.AnError)
|
||||||
t.Log(err)
|
t.Log(err)
|
||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
require.Contains(t, err.Error(), fmt.Sprintf("validating sliceErrorTestDoc.structArray[1].otherField: %s", assert.AnError))
|
require.Contains(t, err.Error(), fmt.Sprintf("validating sliceErrorTestDoc.structArray[1].otherField: %s", assert.AnError))
|
||||||
|
@ -275,7 +306,7 @@ func TestNewValidationErrorStructPointerSlice(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
doc, field := references(t, st, &st.StructPointerSlice[1].OtherField, "")
|
doc, field := references(t, st, &st.StructPointerSlice[1].OtherField, "")
|
||||||
err := newError(doc, field, assert.AnError)
|
err := newTraceError(doc, field, assert.AnError)
|
||||||
t.Log(err)
|
t.Log(err)
|
||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
require.Contains(t, err.Error(), fmt.Sprintf("validating sliceErrorTestDoc.structPointerSlice[1].otherField: %s", assert.AnError))
|
require.Contains(t, err.Error(), fmt.Sprintf("validating sliceErrorTestDoc.structPointerSlice[1].otherField: %s", assert.AnError))
|
||||||
|
@ -296,7 +327,7 @@ func TestNewValidationErrorStructPointerArray(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
doc, field := references(t, st, &st.StructPointerArray[1].OtherField, "")
|
doc, field := references(t, st, &st.StructPointerArray[1].OtherField, "")
|
||||||
err := newError(doc, field, assert.AnError)
|
err := newTraceError(doc, field, assert.AnError)
|
||||||
t.Log(err)
|
t.Log(err)
|
||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
require.Contains(t, err.Error(), fmt.Sprintf("validating sliceErrorTestDoc.structPointerArray[1].otherField: %s", assert.AnError))
|
require.Contains(t, err.Error(), fmt.Sprintf("validating sliceErrorTestDoc.structPointerArray[1].otherField: %s", assert.AnError))
|
||||||
|
@ -311,7 +342,7 @@ func TestNewValidationErrorPrimitiveSliceSlice(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
doc, field := references(t, st, &st.PrimitiveSliceSlice[1][1], "")
|
doc, field := references(t, st, &st.PrimitiveSliceSlice[1][1], "")
|
||||||
err := newError(doc, field, assert.AnError)
|
err := newTraceError(doc, field, assert.AnError)
|
||||||
t.Log(err)
|
t.Log(err)
|
||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
require.Contains(t, err.Error(), fmt.Sprintf("validating sliceErrorTestDoc.primitiveSliceSlice[1][1]: %s", assert.AnError))
|
require.Contains(t, err.Error(), fmt.Sprintf("validating sliceErrorTestDoc.primitiveSliceSlice[1][1]: %s", assert.AnError))
|
||||||
|
@ -328,7 +359,7 @@ func TestNewValidationErrorPrimitiveMap(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
doc, field := references(t, st, &st.PrimitiveMap, "ghi")
|
doc, field := references(t, st, &st.PrimitiveMap, "ghi")
|
||||||
err := newError(doc, field, assert.AnError)
|
err := newTraceError(doc, field, assert.AnError)
|
||||||
t.Log(err)
|
t.Log(err)
|
||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
require.Contains(t, err.Error(), fmt.Sprintf("validating mapErrorTestDoc.primitiveMap[\"ghi\"]: %s", assert.AnError))
|
require.Contains(t, err.Error(), fmt.Sprintf("validating mapErrorTestDoc.primitiveMap[\"ghi\"]: %s", assert.AnError))
|
||||||
|
@ -349,7 +380,7 @@ func TestNewValidationErrorStructPointerMap(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
doc, field := references(t, st, &st.StructPointerMap["ghi"].OtherField, "")
|
doc, field := references(t, st, &st.StructPointerMap["ghi"].OtherField, "")
|
||||||
err := newError(doc, field, assert.AnError)
|
err := newTraceError(doc, field, assert.AnError)
|
||||||
t.Log(err)
|
t.Log(err)
|
||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
require.Contains(t, err.Error(), fmt.Sprintf("validating mapErrorTestDoc.structPointerMap[\"ghi\"].otherField: %s", assert.AnError))
|
require.Contains(t, err.Error(), fmt.Sprintf("validating mapErrorTestDoc.structPointerMap[\"ghi\"].otherField: %s", assert.AnError))
|
||||||
|
@ -368,7 +399,7 @@ func TestNewValidationErrorNestedPrimitiveMap(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
doc, field := references(t, st, st.NestedPointerMap["jkl"], "mno")
|
doc, field := references(t, st, st.NestedPointerMap["jkl"], "mno")
|
||||||
err := newError(doc, field, assert.AnError)
|
err := newTraceError(doc, field, assert.AnError)
|
||||||
t.Log(err)
|
t.Log(err)
|
||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
require.Contains(t, err.Error(), fmt.Sprintf("validating mapErrorTestDoc.nestedPointerMap[\"jkl\"][\"mno\"]: %s", assert.AnError))
|
require.Contains(t, err.Error(), fmt.Sprintf("validating mapErrorTestDoc.nestedPointerMap[\"jkl\"][\"mno\"]: %s", assert.AnError))
|
||||||
|
@ -383,7 +414,7 @@ func TestNewValidationErrorTopLevelIsNeedle(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
doc, field := references(t, st, st, "")
|
doc, field := references(t, st, st, "")
|
||||||
err := newError(doc, field, assert.AnError)
|
err := newTraceError(doc, field, assert.AnError)
|
||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
require.Contains(t, err.Error(), fmt.Sprintf("validating errorTestDoc: %s", assert.AnError))
|
require.Contains(t, err.Error(), fmt.Sprintf("validating errorTestDoc: %s", assert.AnError))
|
||||||
}
|
}
|
||||||
|
@ -396,7 +427,7 @@ func TestNewValidationErrorUntaggedField(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
doc, field := references(t, st, &st.NoTagField, "")
|
doc, field := references(t, st, &st.NoTagField, "")
|
||||||
err := newError(doc, field, assert.AnError)
|
err := newTraceError(doc, field, assert.AnError)
|
||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
require.Contains(t, err.Error(), fmt.Sprintf("validating errorTestDoc.NoTagField: %s", assert.AnError))
|
require.Contains(t, err.Error(), fmt.Sprintf("validating errorTestDoc.NoTagField: %s", assert.AnError))
|
||||||
}
|
}
|
||||||
|
@ -410,7 +441,7 @@ func TestNewValidationErrorOnlyYamlTaggedField(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
doc, field := references(t, st, &st.OnlyYamlKey, "")
|
doc, field := references(t, st, &st.OnlyYamlKey, "")
|
||||||
err := newError(doc, field, assert.AnError)
|
err := newTraceError(doc, field, assert.AnError)
|
||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
require.Contains(t, err.Error(), fmt.Sprintf("validating errorTestDoc.onlyYamlKey: %s", assert.AnError))
|
require.Contains(t, err.Error(), fmt.Sprintf("validating errorTestDoc.onlyYamlKey: %s", assert.AnError))
|
||||||
}
|
}
|
||||||
|
|
|
@ -11,7 +11,19 @@ It validates documents that specify a set of constraints on their content.
|
||||||
*/
|
*/
|
||||||
package validation
|
package validation
|
||||||
|
|
||||||
import "errors"
|
import (
|
||||||
|
"errors"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ErrStrategy is the strategy to use when encountering an error during validation.
|
||||||
|
type ErrStrategy int
|
||||||
|
|
||||||
|
const (
|
||||||
|
// EvaluateAll continues evaluating all constraints even if one is not satisfied.
|
||||||
|
EvaluateAll ErrStrategy = iota
|
||||||
|
// FailFast stops validation on the first error.
|
||||||
|
FailFast
|
||||||
|
)
|
||||||
|
|
||||||
// NewValidator creates a new Validator.
|
// NewValidator creates a new Validator.
|
||||||
func NewValidator() *Validator {
|
func NewValidator() *Validator {
|
||||||
|
@ -24,21 +36,31 @@ type Validator struct{}
|
||||||
// Validatable is implemented by documents that can be validated.
|
// Validatable is implemented by documents that can be validated.
|
||||||
// It returns a list of constraints that must be satisfied for the document to be valid.
|
// It returns a list of constraints that must be satisfied for the document to be valid.
|
||||||
type Validatable interface {
|
type Validatable interface {
|
||||||
Constraints() []Constraint
|
Constraints() []*Constraint
|
||||||
}
|
}
|
||||||
|
|
||||||
// ValidateOptions are the options to use when validating a document.
|
// ValidateOptions are the options to use when validating a document.
|
||||||
type ValidateOptions struct {
|
type ValidateOptions struct {
|
||||||
// FailFast stops validation on the first error.
|
// ErrStrategy is the strategy to use when encountering an error during validation.
|
||||||
FailFast bool
|
ErrStrategy ErrStrategy
|
||||||
|
// OverrideConstraints overrides the constraints to use for validation.
|
||||||
|
// If nil, the constraints returned by the document are used.
|
||||||
|
OverrideConstraints func() []*Constraint
|
||||||
}
|
}
|
||||||
|
|
||||||
// Validate validates a document using the given options.
|
// Validate validates a document using the given options.
|
||||||
func (v *Validator) Validate(doc Validatable, opts ValidateOptions) error {
|
func (v *Validator) Validate(doc Validatable, opts ValidateOptions) error {
|
||||||
|
var constraints func() []*Constraint
|
||||||
|
if opts.OverrideConstraints != nil {
|
||||||
|
constraints = opts.OverrideConstraints
|
||||||
|
} else {
|
||||||
|
constraints = doc.Constraints
|
||||||
|
}
|
||||||
|
|
||||||
var retErr error
|
var retErr error
|
||||||
for _, c := range doc.Constraints() {
|
for _, c := range constraints() {
|
||||||
if err := c.Satisfied(); err != nil {
|
if err := c.Satisfied(); err != nil {
|
||||||
if opts.FailFast {
|
if opts.ErrStrategy == FailFast {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
retErr = errors.Join(retErr, err)
|
retErr = errors.Join(retErr, err)
|
||||||
|
|
|
@ -14,34 +14,39 @@ import (
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var validDoc = func() *exampleDoc {
|
||||||
|
return &exampleDoc{
|
||||||
|
StrField: "abc",
|
||||||
|
NumField: 42,
|
||||||
|
MapField: &map[string]string{
|
||||||
|
"empty": "",
|
||||||
|
},
|
||||||
|
NotEmptyField: "certainly not.",
|
||||||
|
MatchRegexField: "abc",
|
||||||
|
OneOfField: "one",
|
||||||
|
OrLeftField: "left",
|
||||||
|
OrRightField: "right",
|
||||||
|
AndLeftField: "left",
|
||||||
|
AndRightField: "right",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestValidate(t *testing.T) {
|
func TestValidate(t *testing.T) {
|
||||||
testCases := map[string]struct {
|
testCases := map[string]struct {
|
||||||
doc Validatable
|
doc func() *exampleDoc
|
||||||
opts ValidateOptions
|
opts ValidateOptions
|
||||||
wantErr bool
|
wantErr bool
|
||||||
errAssertion func(*assert.Assertions, error) bool
|
errAssertion func(*assert.Assertions, error) bool
|
||||||
}{
|
}{
|
||||||
"valid": {
|
"valid": {
|
||||||
doc: &exampleDoc{
|
doc: validDoc,
|
||||||
StrField: "abc",
|
|
||||||
NumField: 42,
|
|
||||||
MapField: &map[string]string{
|
|
||||||
"empty": "",
|
|
||||||
},
|
|
||||||
NotEmptyField: "certainly not.",
|
|
||||||
MatchRegexField: "abc",
|
|
||||||
},
|
|
||||||
opts: ValidateOptions{},
|
opts: ValidateOptions{},
|
||||||
},
|
},
|
||||||
"strField is not abc": {
|
"strField is not abc": {
|
||||||
doc: &exampleDoc{
|
doc: func() *exampleDoc {
|
||||||
StrField: "def",
|
doc := validDoc()
|
||||||
NumField: 42,
|
doc.StrField = "def"
|
||||||
MapField: &map[string]string{
|
return doc
|
||||||
"empty": "",
|
|
||||||
},
|
|
||||||
NotEmptyField: "certainly not.",
|
|
||||||
MatchRegexField: "abc",
|
|
||||||
},
|
},
|
||||||
wantErr: true,
|
wantErr: true,
|
||||||
errAssertion: func(assert *assert.Assertions, err error) bool {
|
errAssertion: func(assert *assert.Assertions, err error) bool {
|
||||||
|
@ -50,14 +55,10 @@ func TestValidate(t *testing.T) {
|
||||||
opts: ValidateOptions{},
|
opts: ValidateOptions{},
|
||||||
},
|
},
|
||||||
"numField is not 42": {
|
"numField is not 42": {
|
||||||
doc: &exampleDoc{
|
doc: func() *exampleDoc {
|
||||||
StrField: "abc",
|
doc := validDoc()
|
||||||
NumField: 43,
|
doc.NumField = 43
|
||||||
MapField: &map[string]string{
|
return doc
|
||||||
"empty": "",
|
|
||||||
},
|
|
||||||
NotEmptyField: "certainly not.",
|
|
||||||
MatchRegexField: "abc",
|
|
||||||
},
|
},
|
||||||
wantErr: true,
|
wantErr: true,
|
||||||
errAssertion: func(assert *assert.Assertions, err error) bool {
|
errAssertion: func(assert *assert.Assertions, err error) bool {
|
||||||
|
@ -65,14 +66,11 @@ func TestValidate(t *testing.T) {
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
"multiple errors": {
|
"multiple errors": {
|
||||||
doc: &exampleDoc{
|
doc: func() *exampleDoc {
|
||||||
StrField: "def",
|
doc := validDoc()
|
||||||
NumField: 43,
|
doc.StrField = "def"
|
||||||
MapField: &map[string]string{
|
doc.NumField = 43
|
||||||
"empty": "",
|
return doc
|
||||||
},
|
|
||||||
NotEmptyField: "certainly not.",
|
|
||||||
MatchRegexField: "abc",
|
|
||||||
},
|
},
|
||||||
wantErr: true,
|
wantErr: true,
|
||||||
errAssertion: func(assert *assert.Assertions, err error) bool {
|
errAssertion: func(assert *assert.Assertions, err error) bool {
|
||||||
|
@ -82,75 +80,108 @@ func TestValidate(t *testing.T) {
|
||||||
opts: ValidateOptions{},
|
opts: ValidateOptions{},
|
||||||
},
|
},
|
||||||
"multiple errors, fail fast": {
|
"multiple errors, fail fast": {
|
||||||
doc: &exampleDoc{
|
doc: func() *exampleDoc {
|
||||||
StrField: "def",
|
doc := validDoc()
|
||||||
NumField: 43,
|
doc.StrField = "def"
|
||||||
MapField: &map[string]string{
|
doc.NumField = 43
|
||||||
"empty": "",
|
return doc
|
||||||
},
|
|
||||||
NotEmptyField: "certainly not.",
|
|
||||||
MatchRegexField: "abc",
|
|
||||||
},
|
},
|
||||||
wantErr: true,
|
wantErr: true,
|
||||||
errAssertion: func(assert *assert.Assertions, err error) bool {
|
errAssertion: func(assert *assert.Assertions, err error) bool {
|
||||||
return assert.Contains(err.Error(), "validating exampleDoc.strField: def must be abc")
|
return assert.Contains(err.Error(), "validating exampleDoc.strField: def must be abc")
|
||||||
},
|
},
|
||||||
opts: ValidateOptions{
|
opts: ValidateOptions{
|
||||||
FailFast: true,
|
ErrStrategy: FailFast,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
"map field is not empty": {
|
"map field is not empty": {
|
||||||
doc: &exampleDoc{
|
doc: func() *exampleDoc {
|
||||||
StrField: "abc",
|
doc := validDoc()
|
||||||
NumField: 42,
|
doc.MapField = &map[string]string{
|
||||||
MapField: &map[string]string{
|
|
||||||
"empty": "haha!",
|
"empty": "haha!",
|
||||||
},
|
}
|
||||||
NotEmptyField: "certainly not.",
|
return doc
|
||||||
MatchRegexField: "abc",
|
|
||||||
},
|
},
|
||||||
wantErr: true,
|
wantErr: true,
|
||||||
errAssertion: func(assert *assert.Assertions, err error) bool {
|
errAssertion: func(assert *assert.Assertions, err error) bool {
|
||||||
return assert.Contains(err.Error(), "validating exampleDoc.mapField[\"empty\"]: haha! must be empty")
|
return assert.Contains(err.Error(), "validating exampleDoc.mapField[\"empty\"]: haha! must be empty")
|
||||||
},
|
},
|
||||||
opts: ValidateOptions{
|
opts: ValidateOptions{
|
||||||
FailFast: true,
|
ErrStrategy: FailFast,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
"empty field is not empty": {
|
"not empty field is empty": {
|
||||||
doc: &exampleDoc{
|
doc: func() *exampleDoc {
|
||||||
StrField: "abc",
|
doc := validDoc()
|
||||||
NumField: 42,
|
doc.NotEmptyField = ""
|
||||||
MapField: &map[string]string{
|
return doc
|
||||||
"empty": "",
|
|
||||||
},
|
|
||||||
NotEmptyField: "",
|
|
||||||
MatchRegexField: "abc",
|
|
||||||
},
|
},
|
||||||
wantErr: true,
|
wantErr: true,
|
||||||
errAssertion: func(assert *assert.Assertions, err error) bool {
|
errAssertion: func(assert *assert.Assertions, err error) bool {
|
||||||
return assert.Contains(err.Error(), "validating exampleDoc.notEmptyField: must not be empty")
|
return assert.Contains(err.Error(), "validating exampleDoc.notEmptyField: must not be empty")
|
||||||
},
|
},
|
||||||
opts: ValidateOptions{
|
opts: ValidateOptions{
|
||||||
FailFast: true,
|
ErrStrategy: FailFast,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
"regex doesnt match": {
|
"regex doesnt match": {
|
||||||
doc: &exampleDoc{
|
doc: func() *exampleDoc {
|
||||||
StrField: "abc",
|
doc := validDoc()
|
||||||
NumField: 42,
|
doc.MatchRegexField = "dontmatch"
|
||||||
MapField: &map[string]string{
|
return doc
|
||||||
"empty": "",
|
|
||||||
},
|
|
||||||
NotEmptyField: "certainly not!",
|
|
||||||
MatchRegexField: "dontmatch",
|
|
||||||
},
|
},
|
||||||
wantErr: true,
|
wantErr: true,
|
||||||
errAssertion: func(assert *assert.Assertions, err error) bool {
|
errAssertion: func(assert *assert.Assertions, err error) bool {
|
||||||
return assert.Contains(err.Error(), "validating exampleDoc.matchRegexField: dontmatch must match the pattern ^a.c$")
|
return assert.Contains(err.Error(), "validating exampleDoc.matchRegexField: dontmatch must match the pattern ^a.c$")
|
||||||
},
|
},
|
||||||
opts: ValidateOptions{
|
opts: ValidateOptions{
|
||||||
FailFast: true,
|
ErrStrategy: FailFast,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"field is not in 'oneof' values": {
|
||||||
|
doc: func() *exampleDoc {
|
||||||
|
doc := validDoc()
|
||||||
|
doc.OneOfField = "not in oneof"
|
||||||
|
return doc
|
||||||
|
},
|
||||||
|
wantErr: true,
|
||||||
|
errAssertion: func(assert *assert.Assertions, err error) bool {
|
||||||
|
return assert.Contains(err.Error(), "validating exampleDoc.oneOfField: not in oneof must be one of [one two three]")
|
||||||
|
},
|
||||||
|
opts: ValidateOptions{
|
||||||
|
ErrStrategy: FailFast,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"'or' violated": {
|
||||||
|
doc: func() *exampleDoc {
|
||||||
|
doc := validDoc()
|
||||||
|
doc.OrLeftField = "not left"
|
||||||
|
doc.OrRightField = "not right"
|
||||||
|
return doc
|
||||||
|
},
|
||||||
|
wantErr: true,
|
||||||
|
errAssertion: func(assert *assert.Assertions, err error) bool {
|
||||||
|
return assert.Contains(err.Error(), "at least one of the constraints must be satisfied:") &&
|
||||||
|
assert.Contains(err.Error(), "validating exampleDoc.orLeftField: not left must be equal to left") &&
|
||||||
|
assert.Contains(err.Error(), "validating exampleDoc.orRightField: not right must be equal to right")
|
||||||
|
},
|
||||||
|
opts: ValidateOptions{
|
||||||
|
ErrStrategy: FailFast,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"'and' violated": {
|
||||||
|
doc: func() *exampleDoc {
|
||||||
|
doc := validDoc()
|
||||||
|
doc.AndRightField = "not right"
|
||||||
|
return doc
|
||||||
|
},
|
||||||
|
wantErr: true,
|
||||||
|
errAssertion: func(assert *assert.Assertions, err error) bool {
|
||||||
|
return assert.Contains(err.Error(), "all of the constraints must be satisfied:") &&
|
||||||
|
assert.Contains(err.Error(), "validating exampleDoc.andRightField: not right must be equal to right")
|
||||||
|
},
|
||||||
|
opts: ValidateOptions{
|
||||||
|
ErrStrategy: FailFast,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
@ -160,7 +191,7 @@ func TestValidate(t *testing.T) {
|
||||||
assert := assert.New(t)
|
assert := assert.New(t)
|
||||||
require := require.New(t)
|
require := require.New(t)
|
||||||
|
|
||||||
err := NewValidator().Validate(tc.doc, tc.opts)
|
err := NewValidator().Validate(tc.doc(), tc.opts)
|
||||||
if tc.wantErr {
|
if tc.wantErr {
|
||||||
require.Error(err)
|
require.Error(err)
|
||||||
if !tc.errAssertion(assert, err) {
|
if !tc.errAssertion(assert, err) {
|
||||||
|
@ -179,13 +210,18 @@ type exampleDoc struct {
|
||||||
MapField *map[string]string `json:"mapField"`
|
MapField *map[string]string `json:"mapField"`
|
||||||
NotEmptyField string `json:"notEmptyField"`
|
NotEmptyField string `json:"notEmptyField"`
|
||||||
MatchRegexField string `json:"matchRegexField"`
|
MatchRegexField string `json:"matchRegexField"`
|
||||||
|
OneOfField string `json:"oneOfField"`
|
||||||
|
OrLeftField string `json:"orLeftField"`
|
||||||
|
OrRightField string `json:"orRightField"`
|
||||||
|
AndLeftField string `json:"andLeftField"`
|
||||||
|
AndRightField string `json:"andRightField"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// Constraints implements the Validatable interface.
|
// Constraints implements the Validatable interface.
|
||||||
func (d *exampleDoc) Constraints() []Constraint {
|
func (d *exampleDoc) Constraints() []*Constraint {
|
||||||
mapField := *(d.MapField)
|
mapField := *(d.MapField)
|
||||||
|
|
||||||
return []Constraint{
|
return []*Constraint{
|
||||||
d.strFieldNeedsToBeAbc().
|
d.strFieldNeedsToBeAbc().
|
||||||
WithFieldTrace(d, &d.StrField),
|
WithFieldTrace(d, &d.StrField),
|
||||||
Equal(d.NumField, 42).
|
Equal(d.NumField, 42).
|
||||||
|
@ -196,17 +232,95 @@ func (d *exampleDoc) Constraints() []Constraint {
|
||||||
WithFieldTrace(d, &d.NotEmptyField),
|
WithFieldTrace(d, &d.NotEmptyField),
|
||||||
MatchRegex(d.MatchRegexField, "^a.c$").
|
MatchRegex(d.MatchRegexField, "^a.c$").
|
||||||
WithFieldTrace(d, &d.MatchRegexField),
|
WithFieldTrace(d, &d.MatchRegexField),
|
||||||
|
OneOf(d.OneOfField, []string{"one", "two", "three"}).
|
||||||
|
WithFieldTrace(d, &d.OneOfField),
|
||||||
|
Or(
|
||||||
|
Equal(d.OrLeftField, "left").
|
||||||
|
WithFieldTrace(d, &d.OrLeftField),
|
||||||
|
Equal(d.OrRightField, "right").
|
||||||
|
WithFieldTrace(d, &d.OrRightField),
|
||||||
|
),
|
||||||
|
And(
|
||||||
|
EvaluateAll,
|
||||||
|
Equal(d.AndLeftField, "left").
|
||||||
|
WithFieldTrace(d, &d.AndLeftField),
|
||||||
|
Equal(d.AndRightField, "right").
|
||||||
|
WithFieldTrace(d, &d.AndRightField),
|
||||||
|
),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// StrFieldNeedsToBeAbc is an example for a custom constraint.
|
// StrFieldNeedsToBeAbc is an example for a custom constraint.
|
||||||
func (d *exampleDoc) strFieldNeedsToBeAbc() *Constraint {
|
func (d *exampleDoc) strFieldNeedsToBeAbc() *Constraint {
|
||||||
return &Constraint{
|
return &Constraint{
|
||||||
Satisfied: func() error {
|
Satisfied: func() *TreeError {
|
||||||
if d.StrField != "abc" {
|
if d.StrField != "abc" {
|
||||||
return fmt.Errorf("%s must be abc", d.StrField)
|
return NewErrorTree(
|
||||||
|
fmt.Errorf("%s must be abc", d.StrField),
|
||||||
|
)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestOverrideConstraints(t *testing.T) {
|
||||||
|
overrideConstraints := func(t *testing.T, wantCalled bool) func() []*Constraint {
|
||||||
|
return func() []*Constraint {
|
||||||
|
if !wantCalled {
|
||||||
|
t.Fatal("overrideConstraints should not be called")
|
||||||
|
}
|
||||||
|
return []*Constraint{}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
testCases := map[string]struct {
|
||||||
|
doc exampleDocToOverride
|
||||||
|
overrideFunc func() []*Constraint
|
||||||
|
wantOverrideCalled bool
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
"override constraints": {
|
||||||
|
doc: exampleDocToOverride{},
|
||||||
|
overrideFunc: overrideConstraints(t, true),
|
||||||
|
wantOverrideCalled: true,
|
||||||
|
},
|
||||||
|
"do not override constraints": {
|
||||||
|
doc: exampleDocToOverride{
|
||||||
|
calledDocConstraints: true,
|
||||||
|
},
|
||||||
|
overrideFunc: nil,
|
||||||
|
wantOverrideCalled: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for name, tc := range testCases {
|
||||||
|
t.Run(name, func(t *testing.T) {
|
||||||
|
assert := assert.New(t)
|
||||||
|
require := require.New(t)
|
||||||
|
|
||||||
|
validator := NewValidator()
|
||||||
|
err := validator.Validate(&tc.doc, ValidateOptions{
|
||||||
|
OverrideConstraints: tc.overrideFunc,
|
||||||
|
})
|
||||||
|
|
||||||
|
if tc.wantErr {
|
||||||
|
require.Error(err)
|
||||||
|
} else {
|
||||||
|
require.NoError(err)
|
||||||
|
if tc.wantOverrideCalled {
|
||||||
|
assert.Equal(tc.doc.calledDocConstraints, false)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type exampleDocToOverride struct {
|
||||||
|
calledDocConstraints bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *exampleDocToOverride) Constraints() []*Constraint {
|
||||||
|
d.calledDocConstraints = true
|
||||||
|
return []*Constraint{}
|
||||||
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue