cli: state file validation (#2523)

* re-use `ReadFromFile` in `CreateOrRead`

Signed-off-by: Moritz Sanft <58110325+msanft@users.noreply.github.com>

* [wip]: add constraints

Signed-off-by: Moritz Sanft <58110325+msanft@users.noreply.github.com>

* [wip] error formatting

Signed-off-by: Moritz Sanft <58110325+msanft@users.noreply.github.com>

* wip

Signed-off-by: Moritz Sanft <58110325+msanft@users.noreply.github.com>

* formatted error messages

Signed-off-by: Moritz Sanft <58110325+msanft@users.noreply.github.com>

* state file validation

Signed-off-by: Moritz Sanft <58110325+msanft@users.noreply.github.com>

* linter fixes

Signed-off-by: Moritz Sanft <58110325+msanft@users.noreply.github.com>

* allow overriding the constraints

Signed-off-by: Moritz Sanft <58110325+msanft@users.noreply.github.com>

* dont validate on read

Signed-off-by: Moritz Sanft <58110325+msanft@users.noreply.github.com>

* add pre-create constraints

Signed-off-by: Moritz Sanft <58110325+msanft@users.noreply.github.com>

* [wip]

Signed-off-by: Moritz Sanft <58110325+msanft@users.noreply.github.com>

* finish pre-init validation test

Signed-off-by: Moritz Sanft <58110325+msanft@users.noreply.github.com>

* finish post-init validation

Signed-off-by: Moritz Sanft <58110325+msanft@users.noreply.github.com>

* use state file validation in CLI

Signed-off-by: Moritz Sanft <58110325+msanft@users.noreply.github.com>

* fix apply tests

Signed-off-by: Moritz Sanft <58110325+msanft@users.noreply.github.com>

* Update internal/validation/errors.go

Co-authored-by: Daniel Weiße <66256922+daniel-weisse@users.noreply.github.com>

* use transformator for tests

* tidy

* use empty check directly

Signed-off-by: Moritz Sanft <58110325+msanft@users.noreply.github.com>

* Update cli/internal/state/state.go

Co-authored-by: Daniel Weiße <66256922+daniel-weisse@users.noreply.github.com>

* Update cli/internal/state/state.go

Co-authored-by: Daniel Weiße <66256922+daniel-weisse@users.noreply.github.com>

* Update cli/internal/state/state.go

Co-authored-by: Daniel Weiße <66256922+daniel-weisse@users.noreply.github.com>

* Update cli/internal/state/state.go

Co-authored-by: Daniel Weiße <66256922+daniel-weisse@users.noreply.github.com>

* conditional validation per CSP

Signed-off-by: Moritz Sanft <58110325+msanft@users.noreply.github.com>

* tidy

Signed-off-by: Moritz Sanft <58110325+msanft@users.noreply.github.com>

* fix rebase

Signed-off-by: Moritz Sanft <58110325+msanft@users.noreply.github.com>

* add default case

Signed-off-by: Moritz Sanft <58110325+msanft@users.noreply.github.com>

* validate state-file as last input

Signed-off-by: Moritz Sanft <58110325+msanft@users.noreply.github.com>

---------

Signed-off-by: Moritz Sanft <58110325+msanft@users.noreply.github.com>
Co-authored-by: Daniel Weiße <66256922+daniel-weisse@users.noreply.github.com>
This commit is contained in:
Moritz Sanft 2023-11-03 15:47:03 +01:00 committed by GitHub
parent eaec73cca4
commit 744a605602
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
21 changed files with 1779 additions and 247 deletions

View File

@ -419,12 +419,6 @@ func (a *applyCmd) validateInputs(cmd *cobra.Command, configFetcher attestationc
return nil, nil, err
}
a.log.Debugf("Reading state file from %s", a.flags.pathPrefixer.PrefixPrintablePath(constants.StateFilename))
stateFile, err := state.ReadFromFile(a.fileHandler, constants.StateFilename)
if err != nil {
return nil, nil, err
}
// Check license
a.log.Debugf("Running license check")
checker := license.NewChecker(a.quotaChecker, a.fileHandler)
@ -517,6 +511,27 @@ func (a *applyCmd) validateInputs(cmd *cobra.Command, configFetcher attestationc
cmd.PrintErrln("WARNING: Attestation temporarily relies on AWS nitroTPM. See https://docs.edgeless.systems/constellation/workflows/config#choosing-a-vm-type for more information.")
}
// Read and validate state file
// This needs to be done as a last step, as we need to parse all other inputs to
// know which phases are skipped.
a.log.Debugf("Reading state file from %s", a.flags.pathPrefixer.PrefixPrintablePath(constants.StateFilename))
stateFile, err := state.ReadFromFile(a.fileHandler, constants.StateFilename)
if err != nil {
return nil, nil, err
}
if a.flags.skipPhases.contains(skipInitPhase) {
// If the skipInit flag is set, we are in a state where the cluster
// has already been initialized and check against the respective constraints.
if err := stateFile.Validate(state.PostInit, conf.GetProvider()); err != nil {
return nil, nil, err
}
} else {
// The cluster has not been initialized yet, so we check against the pre-init constraints.
if err := stateFile.Validate(state.PreInit, conf.GetProvider()); err != nil {
return nil, nil, err
}
}
return conf, stateFile, nil
}

View File

@ -14,6 +14,7 @@ import (
"time"
"github.com/edgelesssys/constellation/v2/cli/internal/helm"
"github.com/edgelesssys/constellation/v2/cli/internal/state"
"github.com/edgelesssys/constellation/v2/internal/file"
"github.com/edgelesssys/constellation/v2/internal/logger"
"github.com/spf13/afero"
@ -22,6 +23,54 @@ import (
"github.com/stretchr/testify/require"
)
// defaultStateFile returns a valid default state for testing.
func defaultStateFile() *state.State {
return &state.State{
Version: "v1",
Infrastructure: state.Infrastructure{
UID: "123",
Name: "test-cluster",
ClusterEndpoint: "192.0.2.1",
InClusterEndpoint: "192.0.2.1",
InitSecret: []byte{0x41},
APIServerCertSANs: []string{
"127.0.0.1",
"www.example.com",
},
IPCidrNode: "0.0.0.0/24",
Azure: &state.Azure{
ResourceGroup: "test-rg",
SubscriptionID: "test-sub",
NetworkSecurityGroupName: "test-nsg",
LoadBalancerName: "test-lb",
UserAssignedIdentity: "test-uami",
AttestationURL: "test-maaUrl",
},
GCP: &state.GCP{
ProjectID: "test-project",
IPCidrPod: "0.0.0.0/24",
},
},
ClusterValues: state.ClusterValues{
ClusterID: "deadbeef",
OwnerID: "deadbeef",
MeasurementSalt: []byte{0x41},
},
}
}
func defaultAzureStateFile() *state.State {
s := defaultStateFile()
s.Infrastructure.GCP = nil
return s
}
func defaultGCPStateFile() *state.State {
s := defaultStateFile()
s.Infrastructure.Azure = nil
return s
}
func TestParseApplyFlags(t *testing.T) {
require := require.New(t)
defaultFlags := func() *pflag.FlagSet {

View File

@ -202,6 +202,9 @@ func (c *createCmd) create(cmd *cobra.Command, applier cloudApplier, fileHandler
if err != nil {
return fmt.Errorf("reading state file: %w", err)
}
if err := stateFile.Validate(state.PreCreate, conf.GetProvider()); err != nil {
return fmt.Errorf("validating state file: %w", err)
}
stateFile = stateFile.SetInfrastructure(infraState)
if err := stateFile.WriteToFile(fileHandler, constants.StateFilename); err != nil {
return fmt.Errorf("writing state file: %w", err)

View File

@ -22,12 +22,21 @@ import (
"github.com/stretchr/testify/require"
)
// preCreateStateFile returns a state file satisfying the pre-create state file
// constraints.
func preCreateStateFile() *state.State {
s := defaultAzureStateFile()
s.ClusterValues = state.ClusterValues{}
s.Infrastructure = state.Infrastructure{}
return s
}
func TestCreate(t *testing.T) {
fsWithDefaultConfigAndState := func(require *require.Assertions, provider cloudprovider.Provider) afero.Fs {
fs := afero.NewMemMapFs()
file := file.NewHandler(fs)
require.NoError(file.WriteYAML(constants.ConfigFilename, defaultConfigWithExpectedMeasurements(t, config.Default(), provider)))
stateFile := state.New()
stateFile := preCreateStateFile()
switch provider {
case cloudprovider.GCP:
stateFile.SetInfrastructure(state.Infrastructure{GCP: &state.GCP{}})

View File

@ -59,6 +59,14 @@ func TestInitArgumentValidation(t *testing.T) {
assert.Error(cmd.ValidateArgs([]string{"sth", "sth"}))
}
// preInitStateFile returns a state file satisfying the pre-init state file
// constraints.
func preInitStateFile() *state.State {
s := defaultAzureStateFile()
s.ClusterValues = state.ClusterValues{}
return s
}
func TestInitialize(t *testing.T) {
respKubeconfig := k8sclientapi.Config{
Clusters: map[string]*k8sclientapi.Cluster{
@ -101,24 +109,24 @@ func TestInitialize(t *testing.T) {
}{
"initialize some gcp instances": {
provider: cloudprovider.GCP,
stateFile: &state.State{Version: state.Version1, Infrastructure: state.Infrastructure{ClusterEndpoint: "192.0.2.1"}},
stateFile: preInitStateFile(),
configMutator: func(c *config.Config) { c.Provider.GCP.ServiceAccountKeyPath = serviceAccPath },
serviceAccKey: gcpServiceAccKey,
initServerAPI: &stubInitServer{res: []*initproto.InitResponse{{Kind: &initproto.InitResponse_InitSuccess{InitSuccess: testInitResp}}}},
},
"initialize some azure instances": {
provider: cloudprovider.Azure,
stateFile: &state.State{Version: state.Version1, Infrastructure: state.Infrastructure{ClusterEndpoint: "192.0.2.1"}},
stateFile: preInitStateFile(),
initServerAPI: &stubInitServer{res: []*initproto.InitResponse{{Kind: &initproto.InitResponse_InitSuccess{InitSuccess: testInitResp}}}},
},
"initialize some qemu instances": {
provider: cloudprovider.QEMU,
stateFile: &state.State{Version: state.Version1, Infrastructure: state.Infrastructure{ClusterEndpoint: "192.0.2.1"}},
stateFile: preInitStateFile(),
initServerAPI: &stubInitServer{res: []*initproto.InitResponse{{Kind: &initproto.InitResponse_InitSuccess{InitSuccess: testInitResp}}}},
},
"non retriable error": {
provider: cloudprovider.QEMU,
stateFile: &state.State{Version: state.Version1, Infrastructure: state.Infrastructure{ClusterEndpoint: "192.0.2.1"}},
stateFile: preInitStateFile(),
initServerAPI: &stubInitServer{initErr: &nonRetriableError{err: assert.AnError}},
retriable: false,
masterSecretShouldExist: true,
@ -126,7 +134,7 @@ func TestInitialize(t *testing.T) {
},
"non retriable error with failed log collection": {
provider: cloudprovider.QEMU,
stateFile: &state.State{Version: state.Version1, Infrastructure: state.Infrastructure{ClusterEndpoint: "192.0.2.1"}},
stateFile: preInitStateFile(),
initServerAPI: &stubInitServer{
res: []*initproto.InitResponse{
{
@ -149,31 +157,24 @@ func TestInitialize(t *testing.T) {
masterSecretShouldExist: true,
wantErr: true,
},
/*
Tests currently disabled since we don't actually have validation for the state file yet
These tests cases only passed in the past because of unrelated errors in the test setup
TODO(AB#3492): Re-enable tests once state file validation is implemented
"state file with only version": {
provider: cloudprovider.GCP,
stateFile: &state.State{Version: state.Version1},
configMutator: func(c *config.Config) { c.Provider.GCP.ServiceAccountKeyPath = serviceAccPath },
serviceAccKey: gcpServiceAccKey,
initServerAPI: &stubInitServer{},
retriable: true,
wantErr: true,
},
"empty state file": {
provider: cloudprovider.GCP,
stateFile: &state.State{},
configMutator: func(c *config.Config) { c.Provider.GCP.ServiceAccountKeyPath = serviceAccPath },
serviceAccKey: gcpServiceAccKey,
initServerAPI: &stubInitServer{},
retriable: true,
wantErr: true,
},
*/
"invalid state file": {
provider: cloudprovider.GCP,
stateFile: &state.State{Version: "invalid"},
configMutator: func(c *config.Config) { c.Provider.GCP.ServiceAccountKeyPath = serviceAccPath },
serviceAccKey: gcpServiceAccKey,
initServerAPI: &stubInitServer{},
retriable: true,
wantErr: true,
},
"empty state file": {
provider: cloudprovider.GCP,
stateFile: &state.State{},
configMutator: func(c *config.Config) { c.Provider.GCP.ServiceAccountKeyPath = serviceAccPath },
serviceAccKey: gcpServiceAccKey,
initServerAPI: &stubInitServer{},
retriable: true,
wantErr: true,
},
"no state file": {
provider: cloudprovider.GCP,
configMutator: func(c *config.Config) { c.Provider.GCP.ServiceAccountKeyPath = serviceAccPath },
@ -184,7 +185,7 @@ func TestInitialize(t *testing.T) {
"init call fails": {
provider: cloudprovider.GCP,
configMutator: func(c *config.Config) { c.Provider.GCP.ServiceAccountKeyPath = serviceAccPath },
stateFile: &state.State{Version: state.Version1, Infrastructure: state.Infrastructure{ClusterEndpoint: "192.0.2.1"}},
stateFile: preInitStateFile(),
serviceAccKey: gcpServiceAccKey,
initServerAPI: &stubInitServer{initErr: assert.AnError},
retriable: false,
@ -193,7 +194,7 @@ func TestInitialize(t *testing.T) {
},
"k8s version without v works": {
provider: cloudprovider.Azure,
stateFile: &state.State{Version: state.Version1, Infrastructure: state.Infrastructure{ClusterEndpoint: "192.0.2.1"}},
stateFile: preInitStateFile(),
initServerAPI: &stubInitServer{res: []*initproto.InitResponse{{Kind: &initproto.InitResponse_InitSuccess{InitSuccess: testInitResp}}}},
configMutator: func(c *config.Config) {
res, err := versions.NewValidK8sVersion(strings.TrimPrefix(string(versions.Default), "v"), true)
@ -203,7 +204,7 @@ func TestInitialize(t *testing.T) {
},
"outdated k8s patch version doesn't work": {
provider: cloudprovider.Azure,
stateFile: &state.State{Version: state.Version1, Infrastructure: state.Infrastructure{ClusterEndpoint: "192.0.2.1"}},
stateFile: preInitStateFile(),
initServerAPI: &stubInitServer{res: []*initproto.InitResponse{{Kind: &initproto.InitResponse_InitSuccess{InitSuccess: testInitResp}}}},
configMutator: func(c *config.Config) {
v, err := semver.New(versions.SupportedK8sVersions()[0])

View File

@ -119,6 +119,10 @@ func (r *recoverCmd) recover(
if err != nil {
return fmt.Errorf("reading state file: %w", err)
}
if err := stateFile.Validate(state.PostInit, provider); err != nil {
return fmt.Errorf("validating state file: %w", err)
}
endpoint, err := r.parseEndpoint(stateFile)
if err != nil {
return err

View File

@ -15,7 +15,6 @@ import (
"testing"
"time"
"github.com/edgelesssys/constellation/v2/cli/internal/state"
"github.com/edgelesssys/constellation/v2/disk-mapper/recoverproto"
"github.com/edgelesssys/constellation/v2/internal/atls"
"github.com/edgelesssys/constellation/v2/internal/cloud/cloudprovider"
@ -159,7 +158,7 @@ func TestRecover(t *testing.T) {
))
require.NoError(fileHandler.WriteYAML(
constants.StateFilename,
state.New(),
defaultGCPStateFile(),
file.OptNone,
))

View File

@ -33,18 +33,10 @@ import (
)
func TestUpgradeApply(t *testing.T) {
defaultState := state.New().
SetInfrastructure(state.Infrastructure{
APIServerCertSANs: []string{},
UID: "uid",
Name: "kubernetes-uid", // default test cfg uses "kubernetes" prefix
InitSecret: []byte{0x42},
}).
SetClusterValues(state.ClusterValues{MeasurementSalt: []byte{0x41}})
fsWithStateFileAndTfState := func() file.Handler {
fh := file.NewHandler(afero.NewMemMapFs())
require.NoError(t, fh.MkdirAll(constants.TerraformWorkingDir))
require.NoError(t, fh.WriteYAML(constants.StateFilename, defaultState))
require.NoError(t, fh.WriteYAML(constants.StateFilename, defaultAzureStateFile()))
return fh
}
@ -63,20 +55,20 @@ func TestUpgradeApply(t *testing.T) {
kubeUpgrader: &stubKubernetesUpgrader{currentConfig: config.DefaultForAzureSEVSNP()},
helmUpgrader: stubApplier{},
terraformUpgrader: &stubTerraformUpgrader{},
flags: applyFlags{yes: true},
flags: applyFlags{yes: true, skipPhases: skipPhases{skipInitPhase: struct{}{}}},
fh: fsWithStateFileAndTfState,
fhAssertions: func(require *require.Assertions, assert *assert.Assertions, fh file.Handler) {
gotState, err := state.ReadFromFile(fh, constants.StateFilename)
require.NoError(err)
assert.Equal("v1", gotState.Version)
assert.Equal(defaultState, gotState)
assert.Equal(defaultAzureStateFile(), gotState)
},
},
"id file and state file do not exist": {
kubeUpgrader: &stubKubernetesUpgrader{currentConfig: config.DefaultForAzureSEVSNP()},
helmUpgrader: stubApplier{},
terraformUpgrader: &stubTerraformUpgrader{},
flags: applyFlags{yes: true},
flags: applyFlags{yes: true, skipPhases: skipPhases{skipInitPhase: struct{}{}}},
fh: func() file.Handler {
return file.NewHandler(afero.NewMemMapFs())
},
@ -90,7 +82,7 @@ func TestUpgradeApply(t *testing.T) {
helmUpgrader: stubApplier{},
terraformUpgrader: &stubTerraformUpgrader{},
wantErr: true,
flags: applyFlags{yes: true},
flags: applyFlags{yes: true, skipPhases: skipPhases{skipInitPhase: struct{}{}}},
fh: fsWithStateFileAndTfState,
},
"nodeVersion in progress error": {
@ -100,7 +92,7 @@ func TestUpgradeApply(t *testing.T) {
},
helmUpgrader: stubApplier{},
terraformUpgrader: &stubTerraformUpgrader{},
flags: applyFlags{yes: true},
flags: applyFlags{yes: true, skipPhases: skipPhases{skipInitPhase: struct{}{}}},
fh: fsWithStateFileAndTfState,
},
"helm other error": {
@ -110,7 +102,7 @@ func TestUpgradeApply(t *testing.T) {
helmUpgrader: stubApplier{err: assert.AnError},
terraformUpgrader: &stubTerraformUpgrader{},
wantErr: true,
flags: applyFlags{yes: true},
flags: applyFlags{yes: true, skipPhases: skipPhases{skipInitPhase: struct{}{}}},
fh: fsWithStateFileAndTfState,
},
"abort": {
@ -140,7 +132,7 @@ func TestUpgradeApply(t *testing.T) {
helmUpgrader: stubApplier{},
terraformUpgrader: &stubTerraformUpgrader{planTerraformErr: assert.AnError},
wantErr: true,
flags: applyFlags{yes: true},
flags: applyFlags{yes: true, skipPhases: skipPhases{skipInitPhase: struct{}{}}},
fh: fsWithStateFileAndTfState,
},
"apply terraform error": {
@ -153,7 +145,7 @@ func TestUpgradeApply(t *testing.T) {
terraformDiff: true,
},
wantErr: true,
flags: applyFlags{yes: true},
flags: applyFlags{yes: true, skipPhases: skipPhases{skipInitPhase: struct{}{}}},
fh: fsWithStateFileAndTfState,
},
"outdated K8s patch version": {
@ -167,7 +159,7 @@ func TestUpgradeApply(t *testing.T) {
require.NoError(t, err)
return semver.NewFromInt(v.Major(), v.Minor(), v.Patch()-1, "").String()
}(),
flags: applyFlags{yes: true},
flags: applyFlags{yes: true, skipPhases: skipPhases{skipInitPhase: struct{}{}}},
fh: fsWithStateFileAndTfState,
},
"outdated K8s version": {
@ -177,7 +169,7 @@ func TestUpgradeApply(t *testing.T) {
helmUpgrader: stubApplier{},
terraformUpgrader: &stubTerraformUpgrader{},
customK8sVersion: "v1.20.0",
flags: applyFlags{yes: true},
flags: applyFlags{yes: true, skipPhases: skipPhases{skipInitPhase: struct{}{}}},
wantErr: true,
fh: fsWithStateFileAndTfState,
},
@ -191,6 +183,7 @@ func TestUpgradeApply(t *testing.T) {
skipPhases: skipPhases{
skipInfrastructurePhase: struct{}{}, skipHelmPhase: struct{}{},
skipK8sPhase: struct{}{}, skipImagePhase: struct{}{},
skipInitPhase: struct{}{},
},
yes: true,
},
@ -205,7 +198,7 @@ func TestUpgradeApply(t *testing.T) {
flags: applyFlags{
skipPhases: skipPhases{
skipInfrastructurePhase: struct{}{}, skipHelmPhase: struct{}{},
skipK8sPhase: struct{}{},
skipK8sPhase: struct{}{}, skipInitPhase: struct{}{},
},
yes: true,
},
@ -219,10 +212,13 @@ func TestUpgradeApply(t *testing.T) {
terraformUpgrader: &mockTerraformUpgrader{},
flags: applyFlags{
yes: true,
skipPhases: skipPhases{
skipInitPhase: struct{}{},
},
},
fh: func() file.Handler {
fh := file.NewHandler(afero.NewMemMapFs())
require.NoError(t, fh.WriteYAML(constants.StateFilename, defaultState))
require.NoError(t, fh.WriteYAML(constants.StateFilename, defaultAzureStateFile()))
return fh
},
},
@ -230,7 +226,7 @@ func TestUpgradeApply(t *testing.T) {
kubeUpgrader: &stubKubernetesUpgrader{currentConfig: &config.AzureTrustedLaunch{}},
helmUpgrader: stubApplier{},
terraformUpgrader: &stubTerraformUpgrader{},
flags: applyFlags{yes: true},
flags: applyFlags{yes: true, skipPhases: skipPhases{skipInitPhase: struct{}{}}},
fh: fsWithStateFileAndTfState,
wantErr: true,
},

View File

@ -155,6 +155,9 @@ func (c *verifyCmd) verify(cmd *cobra.Command, verifyClient verifyClient, factor
if err != nil {
return fmt.Errorf("reading state file: %w", err)
}
if err := stateFile.Validate(state.PostInit, conf.GetProvider()); err != nil {
return fmt.Errorf("validating state file: %w", err)
}
ownerID, clusterID, err := c.validateIDFlags(cmd, stateFile)
if err != nil {

View File

@ -48,7 +48,7 @@ func TestVerify(t *testing.T) {
formatter *stubAttDocFormatter
nodeEndpointFlag string
clusterIDFlag string
stateFile *state.State
stateFile func() *state.State
wantEndpoint string
skipConfigCreation bool
wantErr bool
@ -58,7 +58,7 @@ func TestVerify(t *testing.T) {
nodeEndpointFlag: "192.0.2.1:1234",
clusterIDFlag: zeroBase64,
protoClient: &stubVerifyClient{},
stateFile: state.New(),
stateFile: defaultGCPStateFile,
wantEndpoint: "192.0.2.1:1234",
formatter: &stubAttDocFormatter{},
},
@ -67,7 +67,7 @@ func TestVerify(t *testing.T) {
nodeEndpointFlag: "192.0.2.1:1234",
clusterIDFlag: zeroBase64,
protoClient: &stubVerifyClient{},
stateFile: state.New(),
stateFile: defaultAzureStateFile,
wantEndpoint: "192.0.2.1:1234",
formatter: &stubAttDocFormatter{},
},
@ -76,7 +76,7 @@ func TestVerify(t *testing.T) {
nodeEndpointFlag: "192.0.2.1",
clusterIDFlag: zeroBase64,
protoClient: &stubVerifyClient{},
stateFile: state.New(),
stateFile: defaultGCPStateFile,
wantEndpoint: "192.0.2.1:" + strconv.Itoa(constants.VerifyServiceNodePortGRPC),
formatter: &stubAttDocFormatter{},
},
@ -84,56 +84,78 @@ func TestVerify(t *testing.T) {
provider: cloudprovider.GCP,
clusterIDFlag: zeroBase64,
protoClient: &stubVerifyClient{},
stateFile: state.New(),
formatter: &stubAttDocFormatter{},
wantErr: true,
stateFile: func() *state.State {
s := defaultGCPStateFile()
s.Infrastructure.ClusterEndpoint = ""
return s
},
formatter: &stubAttDocFormatter{},
wantErr: true,
},
"endpoint from state file": {
provider: cloudprovider.GCP,
clusterIDFlag: zeroBase64,
protoClient: &stubVerifyClient{},
stateFile: &state.State{Infrastructure: state.Infrastructure{ClusterEndpoint: "192.0.2.1"}},
wantEndpoint: "192.0.2.1:" + strconv.Itoa(constants.VerifyServiceNodePortGRPC),
formatter: &stubAttDocFormatter{},
stateFile: func() *state.State {
s := defaultGCPStateFile()
s.Infrastructure.ClusterEndpoint = "192.0.2.1"
return s
},
wantEndpoint: "192.0.2.1:" + strconv.Itoa(constants.VerifyServiceNodePortGRPC),
formatter: &stubAttDocFormatter{},
},
"override endpoint from details file": {
provider: cloudprovider.GCP,
nodeEndpointFlag: "192.0.2.2:1234",
clusterIDFlag: zeroBase64,
protoClient: &stubVerifyClient{},
stateFile: &state.State{Infrastructure: state.Infrastructure{ClusterEndpoint: "192.0.2.1"}},
wantEndpoint: "192.0.2.2:1234",
formatter: &stubAttDocFormatter{},
stateFile: func() *state.State {
s := defaultGCPStateFile()
s.Infrastructure.ClusterEndpoint = "192.0.2.1"
return s
},
wantEndpoint: "192.0.2.2:1234",
formatter: &stubAttDocFormatter{},
},
"invalid endpoint": {
provider: cloudprovider.GCP,
nodeEndpointFlag: ":::::",
clusterIDFlag: zeroBase64,
protoClient: &stubVerifyClient{},
stateFile: state.New(),
stateFile: defaultGCPStateFile,
formatter: &stubAttDocFormatter{},
wantErr: true,
},
"neither owner id nor cluster id set": {
provider: cloudprovider.GCP,
nodeEndpointFlag: "192.0.2.1:1234",
stateFile: state.New(),
formatter: &stubAttDocFormatter{},
wantErr: true,
stateFile: func() *state.State {
s := defaultGCPStateFile()
s.ClusterValues.OwnerID = ""
s.ClusterValues.ClusterID = ""
return s
},
formatter: &stubAttDocFormatter{},
protoClient: &stubVerifyClient{},
wantErr: true,
},
"use owner id from state file": {
provider: cloudprovider.GCP,
nodeEndpointFlag: "192.0.2.1:1234",
protoClient: &stubVerifyClient{},
stateFile: &state.State{ClusterValues: state.ClusterValues{OwnerID: zeroBase64}},
wantEndpoint: "192.0.2.1:1234",
formatter: &stubAttDocFormatter{},
stateFile: func() *state.State {
s := defaultGCPStateFile()
s.ClusterValues.OwnerID = zeroBase64
return s
},
wantEndpoint: "192.0.2.1:1234",
formatter: &stubAttDocFormatter{},
},
"config file not existing": {
provider: cloudprovider.GCP,
clusterIDFlag: zeroBase64,
nodeEndpointFlag: "192.0.2.1:1234",
stateFile: state.New(),
stateFile: defaultGCPStateFile,
formatter: &stubAttDocFormatter{},
skipConfigCreation: true,
wantErr: true,
@ -143,7 +165,7 @@ func TestVerify(t *testing.T) {
nodeEndpointFlag: "192.0.2.1:1234",
clusterIDFlag: zeroBase64,
protoClient: &stubVerifyClient{verifyErr: rpcStatus.Error(codes.Internal, "failed")},
stateFile: state.New(),
stateFile: defaultAzureStateFile,
formatter: &stubAttDocFormatter{},
wantErr: true,
},
@ -152,7 +174,7 @@ func TestVerify(t *testing.T) {
nodeEndpointFlag: "192.0.2.1:1234",
clusterIDFlag: zeroBase64,
protoClient: &stubVerifyClient{verifyErr: someErr},
stateFile: state.New(),
stateFile: defaultAzureStateFile,
formatter: &stubAttDocFormatter{},
wantErr: true,
},
@ -161,7 +183,7 @@ func TestVerify(t *testing.T) {
nodeEndpointFlag: "192.0.2.1:1234",
clusterIDFlag: zeroBase64,
protoClient: &stubVerifyClient{},
stateFile: state.New(),
stateFile: defaultAzureStateFile,
wantEndpoint: "192.0.2.1:1234",
formatter: &stubAttDocFormatter{formatErr: someErr},
wantErr: true,
@ -182,7 +204,7 @@ func TestVerify(t *testing.T) {
cfg := defaultConfigWithExpectedMeasurements(t, config.Default(), tc.provider)
require.NoError(fileHandler.WriteYAML(constants.ConfigFilename, cfg))
}
require.NoError(tc.stateFile.WriteToFile(fileHandler, constants.StateFilename))
require.NoError(tc.stateFile().WriteToFile(fileHandler, constants.StateFilename))
v := &verifyCmd{
fileHandler: fileHandler,

View File

@ -10,7 +10,9 @@ go_library(
importpath = "github.com/edgelesssys/constellation/v2/cli/internal/state",
visibility = ["//cli:__subpackages__"],
deps = [
"//internal/cloud/cloudprovider",
"//internal/file",
"//internal/validation",
"@cat_dario_mergo//:mergo",
"@com_github_siderolabs_talos_pkg_machinery//config/encoder",
],
@ -18,9 +20,13 @@ go_library(
go_test(
name = "state_test",
srcs = ["state_test.go"],
srcs = [
"state_test.go",
"validation_test.go",
],
embed = [":state"],
deps = [
"//internal/cloud/cloudprovider",
"//internal/constants",
"//internal/file",
"@com_github_siderolabs_talos_pkg_machinery//config/encoder",

View File

@ -19,7 +19,9 @@ import (
"os"
"dario.cat/mergo"
"github.com/edgelesssys/constellation/v2/internal/cloud/cloudprovider"
"github.com/edgelesssys/constellation/v2/internal/file"
"github.com/edgelesssys/constellation/v2/internal/validation"
)
const (
@ -27,25 +29,44 @@ const (
Version1 = "v1"
)
// ReadFromFile reads the state file at the given path and returns the state.
const (
// PreCreate are the constraints that should be enforced when the state file
// is validated before cloud infrastructure is created.
PreCreate ConstraintSet = iota
// PreInit are the constraints that should be enforced when the state file
// is validated before the first Constellation node is initialized.
PreInit
// PostInit are the constraints that should be enforced when the state file
// is validated after the cluster was initialized.
PostInit
)
// ConstraintSet defines which constraints the state file
// should be validated against.
type ConstraintSet int
// ReadFromFile reads the state file at the given path and validates it.
// If the state file is valid, the state is returned. Otherwise, an error
// describing why the validation failed is returned.
func ReadFromFile(fileHandler file.Handler, path string) (*State, error) {
state := &State{}
if err := fileHandler.ReadYAML(path, &state); err != nil {
return nil, fmt.Errorf("reading state file: %w", err)
}
return state, nil
}
// CreateOrRead reads the state file at the given path, if it exists, and returns the state.
// If the file does not exist, a new state is created and written to disk.
func CreateOrRead(fileHandler file.Handler, path string) (*State, error) {
state := &State{}
if err := fileHandler.ReadYAML(path, &state); err != nil {
state, err := ReadFromFile(fileHandler, path)
if err != nil {
if !errors.Is(err, os.ErrNotExist) {
return nil, fmt.Errorf("reading state file: %w", err)
}
state = New()
return state, state.WriteToFile(fileHandler, path)
newState := New()
return newState, newState.WriteToFile(fileHandler, path)
}
return state, nil
}
@ -186,6 +207,349 @@ func (s *State) Merge(other *State) (*State, error) {
return s, nil
}
/*
Validate validates the state against the given constraint set and CSP, which can be one of
- PreCreate, which is the constraint set that should be enforced before "constellation create" is run.
- PreInit, which is the constraint set that should be enforced before "constellation apply" is run.
- PostInit, which is the constraint set that should be enforced after "constellation apply" is run.
*/
func (s *State) Validate(constraintSet ConstraintSet, csp cloudprovider.Provider) error {
v := validation.NewValidator()
switch constraintSet {
case PreCreate:
return v.Validate(s, validation.ValidateOptions{
OverrideConstraints: s.preCreateConstraints,
})
case PreInit:
return v.Validate(s, validation.ValidateOptions{
OverrideConstraints: s.preInitConstraints,
})
case PostInit:
return v.Validate(s, validation.ValidateOptions{
OverrideConstraints: s.postInitConstraints(csp),
})
default:
return errors.New("unknown constraint set")
}
}
// preCreateConstraints are the constraints on the state that should be enforced
// before a Constellation cluster is created.
//
// The constraints check if the state file version is valid,
// and if all fields are empty, which is a requirement pre-create.
func (s *State) preCreateConstraints() []*validation.Constraint {
return []*validation.Constraint{
// state version needs to be accepted by the parsing CLI.
validation.OneOf(s.Version, []string{Version1}).
WithFieldTrace(s, &s.Version),
// Infrastructure must be empty.
// As the infrastructure struct contains slices, we cannot use the
// Empty constraint on the entire struct. Instead, we need to check
// each field individually.
validation.Empty(s.Infrastructure.UID).
WithFieldTrace(s, &s.Infrastructure.UID),
validation.Empty(s.Infrastructure.ClusterEndpoint).
WithFieldTrace(s, &s.Infrastructure.ClusterEndpoint),
validation.Empty(s.Infrastructure.InClusterEndpoint).
WithFieldTrace(s, &s.Infrastructure.InClusterEndpoint),
validation.Empty(s.Infrastructure.Name).
WithFieldTrace(s, &s.Infrastructure.Name),
validation.Empty(s.Infrastructure.IPCidrNode).
WithFieldTrace(s, &s.Infrastructure.IPCidrNode),
validation.EmptySlice(s.Infrastructure.APIServerCertSANs).
WithFieldTrace(s, &s.Infrastructure.APIServerCertSANs),
validation.EmptySlice(s.Infrastructure.InitSecret).
WithFieldTrace(s, &s.Infrastructure.InitSecret),
// ClusterValues must be empty.
// As the clusterValues struct contains slices, we cannot use the
// Empty constraint on the entire struct. Instead, we need to check
// each field individually.
validation.Empty(s.ClusterValues.ClusterID).
WithFieldTrace(s, &s.ClusterValues.ClusterID),
validation.Empty(s.ClusterValues.OwnerID).
WithFieldTrace(s, &s.ClusterValues.OwnerID),
validation.EmptySlice(s.ClusterValues.MeasurementSalt).
WithFieldTrace(s, &s.ClusterValues.MeasurementSalt),
}
}
// preInitConstraints are the constraints on the state that should be enforced
// *before* a Constellation cluster is initialized. (i.e. before "constellation apply" is run.)
//
// The constraints check if the infrastructure state is valid, and if the cluster values
// are empty, which is required for the cluster to initialize correctly.
func (s *State) preInitConstraints() []*validation.Constraint {
return []*validation.Constraint{
// state version needs to be accepted by the parsing CLI.
validation.OneOf(s.Version, []string{Version1}).
WithFieldTrace(s, &s.Version),
// infrastructure must be valid.
// out-of-cluster endpoint needs to be a valid DNS name or IP address.
validation.Or(
validation.DNSName(s.Infrastructure.ClusterEndpoint).
WithFieldTrace(s, &s.Infrastructure.ClusterEndpoint),
validation.IPAddress(s.Infrastructure.ClusterEndpoint).
WithFieldTrace(s, &s.Infrastructure.ClusterEndpoint),
),
// in-cluster endpoint needs to be a valid DNS name or IP address.
validation.Or(
validation.DNSName(s.Infrastructure.InClusterEndpoint).
WithFieldTrace(s, &s.Infrastructure.InClusterEndpoint),
validation.IPAddress(s.Infrastructure.InClusterEndpoint).
WithFieldTrace(s, &s.Infrastructure.InClusterEndpoint),
),
// Node IP Cidr needs to be a valid CIDR range.
validation.CIDR(s.Infrastructure.IPCidrNode).
WithFieldTrace(s, &s.Infrastructure.IPCidrNode),
// UID needs to be filled.
validation.NotEmpty(s.Infrastructure.UID).
WithFieldTrace(s, &s.Infrastructure.UID),
// Name needs to be filled.
validation.NotEmpty(s.Infrastructure.Name).
WithFieldTrace(s, &s.Infrastructure.Name),
// GCP values need to be nil, empty, or valid.
validation.Or(
validation.Or(
// nil.
validation.Equal(s.Infrastructure.GCP, nil).
WithFieldTrace(s, &s.Infrastructure.GCP),
// empty.
validation.IfNotNil(
s.Infrastructure.GCP,
func() *validation.Constraint {
return validation.Empty(*s.Infrastructure.GCP).
WithFieldTrace(s, &s.Infrastructure.GCP)
},
),
),
// valid.
validation.IfNotNil(
s.Infrastructure.GCP,
func() *validation.Constraint {
return validation.And(
validation.EvaluateAll,
// ProjectID needs to be filled.
validation.NotEmpty(s.Infrastructure.GCP.ProjectID).
WithFieldTrace(s, &s.Infrastructure.GCP.ProjectID),
// Pod IP Cidr needs to be a valid CIDR range.
validation.CIDR(s.Infrastructure.GCP.IPCidrPod).
WithFieldTrace(s, &s.Infrastructure.GCP.IPCidrPod),
)
},
),
),
// Azure values need to be nil, empty, or valid.
validation.Or(
validation.Or(
// nil.
validation.Equal(s.Infrastructure.Azure, nil).
WithFieldTrace(s, &s.Infrastructure.Azure),
// empty.
validation.IfNotNil(
s.Infrastructure.Azure,
func() *validation.Constraint {
return validation.And(
validation.EvaluateAll,
validation.Empty(s.Infrastructure.Azure.ResourceGroup).
WithFieldTrace(s, &s.Infrastructure.Azure.ResourceGroup),
validation.Empty(s.Infrastructure.Azure.SubscriptionID).
WithFieldTrace(s, &s.Infrastructure.Azure.SubscriptionID),
validation.Empty(s.Infrastructure.Azure.NetworkSecurityGroupName).
WithFieldTrace(s, &s.Infrastructure.Azure.NetworkSecurityGroupName),
validation.Empty(s.Infrastructure.Azure.LoadBalancerName).
WithFieldTrace(s, &s.Infrastructure.Azure.LoadBalancerName),
validation.Empty(s.Infrastructure.Azure.UserAssignedIdentity).
WithFieldTrace(s, &s.Infrastructure.Azure.UserAssignedIdentity),
validation.Empty(s.Infrastructure.Azure.AttestationURL).
WithFieldTrace(s, &s.Infrastructure.Azure.AttestationURL),
)
},
),
),
// valid.
validation.IfNotNil(
s.Infrastructure.Azure,
func() *validation.Constraint {
return validation.And(
validation.EvaluateAll,
validation.NotEmpty(s.Infrastructure.Azure.ResourceGroup).
WithFieldTrace(s, &s.Infrastructure.Azure.ResourceGroup),
validation.NotEmpty(s.Infrastructure.Azure.SubscriptionID).
WithFieldTrace(s, &s.Infrastructure.Azure.SubscriptionID),
validation.NotEmpty(s.Infrastructure.Azure.NetworkSecurityGroupName).
WithFieldTrace(s, &s.Infrastructure.Azure.NetworkSecurityGroupName),
validation.NotEmpty(s.Infrastructure.Azure.LoadBalancerName).
WithFieldTrace(s, &s.Infrastructure.Azure.LoadBalancerName),
validation.NotEmpty(s.Infrastructure.Azure.UserAssignedIdentity).
WithFieldTrace(s, &s.Infrastructure.Azure.UserAssignedIdentity),
validation.NotEmpty(s.Infrastructure.Azure.AttestationURL).
WithFieldTrace(s, &s.Infrastructure.Azure.AttestationURL),
)
},
),
),
// ClusterValues must be empty.
// As the clusterValues struct contains slices, we cannot use the
// Empty constraint on the entire struct. Instead, we need to check
// each field individually.
validation.Empty(s.ClusterValues.ClusterID).
WithFieldTrace(s, &s.ClusterValues.ClusterID),
validation.Empty(s.ClusterValues.OwnerID).
WithFieldTrace(s, &s.ClusterValues.OwnerID),
validation.EmptySlice(s.ClusterValues.MeasurementSalt).
WithFieldTrace(s, &s.ClusterValues.MeasurementSalt),
}
}
// postInitConstraints are the constraints on the state that should be enforced
// *after* a Constellation cluster is initialized. (i.e. before "constellation apply" is run.)
//
// The constraints check if the infrastructure state and cluster state
// is valid, so that the cluster can be used correctly.
func (s *State) postInitConstraints(csp cloudprovider.Provider) func() []*validation.Constraint {
return func() []*validation.Constraint {
constraints := []*validation.Constraint{
// state version needs to be accepted by the parsing CLI.
validation.OneOf(s.Version, []string{Version1}).
WithFieldTrace(s, &s.Version),
// infrastructure must be valid.
// out-of-cluster endpoint needs to be a valid DNS name or IP address.
validation.Or(
validation.DNSName(s.Infrastructure.ClusterEndpoint).
WithFieldTrace(s, &s.Infrastructure.ClusterEndpoint),
validation.IPAddress(s.Infrastructure.ClusterEndpoint).
WithFieldTrace(s, &s.Infrastructure.ClusterEndpoint),
),
// in-cluster endpoint needs to be a valid DNS name or IP address.
validation.Or(
validation.DNSName(s.Infrastructure.InClusterEndpoint).
WithFieldTrace(s, &s.Infrastructure.InClusterEndpoint),
validation.IPAddress(s.Infrastructure.InClusterEndpoint).
WithFieldTrace(s, &s.Infrastructure.InClusterEndpoint),
),
// Node IP Cidr needs to be a valid CIDR range.
validation.CIDR(s.Infrastructure.IPCidrNode).
WithFieldTrace(s, &s.Infrastructure.IPCidrNode),
// UID needs to be filled.
validation.NotEmpty(s.Infrastructure.UID).
WithFieldTrace(s, &s.Infrastructure.UID),
// Name needs to be filled.
validation.NotEmpty(s.Infrastructure.Name).
WithFieldTrace(s, &s.Infrastructure.Name),
// ClusterValues need to be valid.
// ClusterID needs to be filled.
validation.NotEmpty(s.ClusterValues.ClusterID).
WithFieldTrace(s, &s.ClusterValues.ClusterID),
// OwnerID needs to be filled.
validation.NotEmpty(s.ClusterValues.OwnerID).
WithFieldTrace(s, &s.ClusterValues.OwnerID),
// MeasurementSalt needs to be filled.
validation.NotEmptySlice(s.ClusterValues.MeasurementSalt).
WithFieldTrace(s, &s.ClusterValues.MeasurementSalt),
}
switch csp {
case cloudprovider.Azure:
constraints = append(constraints,
// GCP values need to be nil or empty.
validation.Or(
validation.Equal(s.Infrastructure.GCP, nil).
WithFieldTrace(s, &s.Infrastructure.GCP),
validation.IfNotNil(
s.Infrastructure.GCP,
func() *validation.Constraint {
return validation.Empty(s.Infrastructure.GCP).
WithFieldTrace(s, &s.Infrastructure.GCP)
},
)),
// Azure values need to be valid.
validation.IfNotNil(
s.Infrastructure.Azure,
func() *validation.Constraint {
return validation.And(
validation.EvaluateAll,
validation.NotEmpty(s.Infrastructure.Azure.ResourceGroup).
WithFieldTrace(s, &s.Infrastructure.Azure.ResourceGroup),
validation.NotEmpty(s.Infrastructure.Azure.SubscriptionID).
WithFieldTrace(s, &s.Infrastructure.Azure.SubscriptionID),
validation.NotEmpty(s.Infrastructure.Azure.NetworkSecurityGroupName).
WithFieldTrace(s, &s.Infrastructure.Azure.NetworkSecurityGroupName),
validation.NotEmpty(s.Infrastructure.Azure.LoadBalancerName).
WithFieldTrace(s, &s.Infrastructure.Azure.LoadBalancerName),
validation.NotEmpty(s.Infrastructure.Azure.UserAssignedIdentity).
WithFieldTrace(s, &s.Infrastructure.Azure.UserAssignedIdentity),
validation.NotEmpty(s.Infrastructure.Azure.AttestationURL).
WithFieldTrace(s, &s.Infrastructure.Azure.AttestationURL),
)
},
),
)
case cloudprovider.GCP:
constraints = append(constraints,
// Azure values need to be nil or empty.
validation.Or(
validation.Equal(s.Infrastructure.Azure, nil).
WithFieldTrace(s, &s.Infrastructure.Azure),
validation.IfNotNil(
s.Infrastructure.Azure,
func() *validation.Constraint {
return validation.Empty(s.Infrastructure.Azure).
WithFieldTrace(s, &s.Infrastructure.Azure)
},
)),
// GCP values need to be valid.
validation.IfNotNil(
s.Infrastructure.GCP,
func() *validation.Constraint {
return validation.And(
validation.EvaluateAll,
// ProjectID needs to be filled.
validation.NotEmpty(s.Infrastructure.GCP.ProjectID).
WithFieldTrace(s, &s.Infrastructure.GCP.ProjectID),
// Pod IP Cidr needs to be a valid CIDR range.
validation.CIDR(s.Infrastructure.GCP.IPCidrPod).
WithFieldTrace(s, &s.Infrastructure.GCP.IPCidrPod),
)
},
),
)
default:
constraints = append(constraints,
// GCP values need to be nil or empty.
validation.Or(
validation.Equal(s.Infrastructure.GCP, nil).
WithFieldTrace(s, &s.Infrastructure.GCP),
validation.IfNotNil(
s.Infrastructure.GCP,
func() *validation.Constraint {
return validation.Empty(s.Infrastructure.GCP).
WithFieldTrace(s, &s.Infrastructure.GCP)
},
)),
// Azure values need to be nil or empty.
validation.Or(
validation.Equal(s.Infrastructure.Azure, nil).
WithFieldTrace(s, &s.Infrastructure.Azure),
validation.IfNotNil(
s.Infrastructure.Azure,
func() *validation.Constraint {
return validation.Empty(s.Infrastructure.Azure).
WithFieldTrace(s, &s.Infrastructure.Azure)
},
)),
)
}
return constraints
}
}
// Constraints is a no-op implementation to fulfill the "Validatable" interface.
func (s *State) Constraints() []*validation.Constraint {
return []*validation.Constraint{}
}
// HexBytes is a byte slice that is marshalled to and from a hex string.
type HexBytes []byte

View File

@ -18,18 +18,21 @@ import (
"gopkg.in/yaml.v3"
)
// defaultState returns a valid default state for testing.
func defaultState() *State {
return &State{
Version: "v1",
Infrastructure: Infrastructure{
UID: "123",
ClusterEndpoint: "test-cluster-endpoint",
InitSecret: []byte{0x41},
UID: "123",
Name: "test-cluster",
ClusterEndpoint: "0.0.0.0",
InClusterEndpoint: "0.0.0.0",
InitSecret: []byte{0x41},
APIServerCertSANs: []string{
"api-server-cert-san-test",
"api-server-cert-san-test-2",
"127.0.0.1",
"www.example.com",
},
IPCidrNode: "test-cidr-node",
IPCidrNode: "0.0.0.0/24",
Azure: &Azure{
ResourceGroup: "test-rg",
SubscriptionID: "test-sub",
@ -40,7 +43,7 @@ func defaultState() *State {
},
GCP: &GCP{
ProjectID: "test-project",
IPCidrPod: "test-cidr-pod",
IPCidrPod: "0.0.0.0/24",
},
},
ClusterValues: ClusterValues{
@ -51,6 +54,18 @@ func defaultState() *State {
}
}
func defaultAzureState() *State {
s := defaultState()
s.Infrastructure.GCP = nil
return s
}
func defaultGCPState() *State {
s := defaultState()
s.Infrastructure.Azure = nil
return s
}
func TestWriteToFile(t *testing.T) {
testCases := map[string]struct {
state *State

View File

@ -0,0 +1,379 @@
/*
Copyright (c) Edgeless Systems GmbH
SPDX-License-Identifier: AGPL-3.0-only
*/
package state
import (
"testing"
"github.com/edgelesssys/constellation/v2/internal/cloud/cloudprovider"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestPreCreateValidation(t *testing.T) {
testCases := map[string]struct {
stateFile func() *State
wantErr bool
errAssertions func(a *assert.Assertions, err error)
}{
"valid": {
stateFile: func() *State {
return &State{
Version: Version1,
}
},
},
"invalid version": {
stateFile: func() *State {
return &State{
Version: "invalid",
}
},
wantErr: true,
errAssertions: func(a *assert.Assertions, err error) {
a.Contains(err.Error(), "validating State.version: invalid must be one of [v1]")
},
},
"infrastructure not empty": {
stateFile: func() *State {
return &State{
Version: Version1,
Infrastructure: Infrastructure{
ClusterEndpoint: "test",
},
}
},
wantErr: true,
errAssertions: func(a *assert.Assertions, err error) {
a.Contains(err.Error(), "validating State.infrastructure.clusterEndpoint: test must be empty")
},
},
"cluster values not empty": {
stateFile: func() *State {
return &State{
Version: Version1,
ClusterValues: ClusterValues{
ClusterID: "test",
},
}
},
wantErr: true,
errAssertions: func(a *assert.Assertions, err error) {
a.Contains(err.Error(), "validating State.clusterValues.clusterID: test must be empty")
},
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
err := tc.stateFile().Validate(PreCreate, cloudprovider.Azure)
if tc.wantErr {
require.Error(t, err)
if tc.errAssertions != nil {
tc.errAssertions(assert.New(t), err)
}
} else {
require.NoError(t, err)
}
})
}
}
func TestPreInitValidation(t *testing.T) {
validPreInitState := func() *State {
s := defaultState()
s.ClusterValues = ClusterValues{}
return s
}
testCases := map[string]struct {
stateFile func() *State
wantErr bool
errAssertions func(a *assert.Assertions, err error)
}{
"valid": {
stateFile: validPreInitState,
},
"invalid version": {
stateFile: func() *State {
s := validPreInitState()
s.Version = "invalid"
return s
},
wantErr: true,
errAssertions: func(a *assert.Assertions, err error) {
a.Contains(err.Error(), "validating State.version: invalid must be one of [v1]")
},
},
"cluster endpoint invalid": {
stateFile: func() *State {
s := validPreInitState()
s.Infrastructure.ClusterEndpoint = "invalid"
return s
},
wantErr: true,
errAssertions: func(a *assert.Assertions, err error) {
a.Contains(err.Error(), "validating State.infrastructure.clusterEndpoint: invalid must be a valid DNS name")
a.Contains(err.Error(), "validating State.infrastructure.clusterEndpoint: invalid must be a valid IP address")
},
},
"in-cluster endpoint invalid": {
stateFile: func() *State {
s := validPreInitState()
s.Infrastructure.InClusterEndpoint = "invalid"
return s
},
wantErr: true,
errAssertions: func(a *assert.Assertions, err error) {
a.Contains(err.Error(), "validating State.infrastructure.inClusterEndpoint: invalid must be a valid DNS name")
a.Contains(err.Error(), "validating State.infrastructure.inClusterEndpoint: invalid must be a valid IP address")
},
},
"node ip cidr invalid": {
stateFile: func() *State {
s := validPreInitState()
s.Infrastructure.IPCidrNode = "invalid"
return s
},
wantErr: true,
errAssertions: func(a *assert.Assertions, err error) {
a.Contains(err.Error(), "validating State.infrastructure.ipCidrNode: invalid must be a valid CIDR")
},
},
"uid empty": {
stateFile: func() *State {
s := validPreInitState()
s.Infrastructure.UID = ""
return s
},
wantErr: true,
errAssertions: func(a *assert.Assertions, err error) {
a.Contains(err.Error(), "validating State.infrastructure.uid: must not be empty")
},
},
"name empty": {
stateFile: func() *State {
s := validPreInitState()
s.Infrastructure.Name = ""
return s
},
wantErr: true,
errAssertions: func(a *assert.Assertions, err error) {
a.Contains(err.Error(), "validating State.infrastructure.name: must not be empty")
},
},
"gcp empty": {
stateFile: func() *State {
s := validPreInitState()
s.Infrastructure.GCP = &GCP{}
return s
},
},
"gcp nil": {
stateFile: func() *State {
s := validPreInitState()
s.Infrastructure.GCP = nil
return s
},
},
"gcp invalid": {
stateFile: func() *State {
s := validPreInitState()
s.Infrastructure.GCP.IPCidrPod = "invalid"
return s
},
wantErr: true,
errAssertions: func(a *assert.Assertions, err error) {
a.Contains(err.Error(), "validating State.infrastructure.gcp.ipCidrPod: invalid must be a valid CIDR")
},
},
"azure empty": {
stateFile: func() *State {
s := validPreInitState()
s.Infrastructure.Azure = &Azure{}
return s
},
},
"azure nil": {
stateFile: func() *State {
s := validPreInitState()
s.Infrastructure.Azure = nil
return s
},
},
"azure invalid": {
stateFile: func() *State {
s := validPreInitState()
s.Infrastructure.Azure.NetworkSecurityGroupName = ""
return s
},
wantErr: true,
errAssertions: func(a *assert.Assertions, err error) {
a.Contains(err.Error(), "validating State.infrastructure.azure.networkSecurityGroupName: must not be empty")
},
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
err := tc.stateFile().Validate(PreInit, cloudprovider.Azure)
if tc.wantErr {
require.Error(t, err)
if tc.errAssertions != nil {
tc.errAssertions(assert.New(t), err)
}
} else {
require.NoError(t, err)
}
})
}
}
func TestPostInitValidation(t *testing.T) {
testCases := map[string]struct {
stateFile func() *State
provider cloudprovider.Provider
wantErr bool
errAssertions func(a *assert.Assertions, err error)
}{
"valid": {
stateFile: defaultGCPState,
provider: cloudprovider.GCP,
},
"invalid version": {
stateFile: func() *State {
s := defaultState()
s.Version = "invalid"
return s
},
wantErr: true,
errAssertions: func(a *assert.Assertions, err error) {
a.Contains(err.Error(), "validating State.version: invalid must be one of [v1]")
},
},
"cluster endpoint invalid": {
stateFile: func() *State {
s := defaultState()
s.Infrastructure.ClusterEndpoint = "invalid"
return s
},
wantErr: true,
errAssertions: func(a *assert.Assertions, err error) {
a.Contains(err.Error(), "validating State.infrastructure.clusterEndpoint: invalid must be a valid DNS name")
a.Contains(err.Error(), "validating State.infrastructure.clusterEndpoint: invalid must be a valid IP address")
},
},
"in-cluster endpoint invalid": {
stateFile: func() *State {
s := defaultState()
s.Infrastructure.InClusterEndpoint = "invalid"
return s
},
wantErr: true,
errAssertions: func(a *assert.Assertions, err error) {
a.Contains(err.Error(), "validating State.infrastructure.inClusterEndpoint: invalid must be a valid DNS name")
a.Contains(err.Error(), "validating State.infrastructure.inClusterEndpoint: invalid must be a valid IP address")
},
},
"node ip cidr invalid": {
stateFile: func() *State {
s := defaultState()
s.Infrastructure.IPCidrNode = "invalid"
return s
},
wantErr: true,
errAssertions: func(a *assert.Assertions, err error) {
a.Contains(err.Error(), "validating State.infrastructure.ipCidrNode: invalid must be a valid CIDR")
},
},
"uid empty": {
stateFile: func() *State {
s := defaultState()
s.Infrastructure.UID = ""
return s
},
wantErr: true,
errAssertions: func(a *assert.Assertions, err error) {
a.Contains(err.Error(), "validating State.infrastructure.uid: must not be empty")
},
},
"name empty": {
stateFile: func() *State {
s := defaultState()
s.Infrastructure.Name = ""
return s
},
wantErr: true,
errAssertions: func(a *assert.Assertions, err error) {
a.Contains(err.Error(), "validating State.infrastructure.name: must not be empty")
},
},
"gcp valid": {
stateFile: func() *State {
s := defaultGCPState()
return s
},
provider: cloudprovider.GCP,
},
"azure valid": {
stateFile: func() *State {
s := defaultAzureState()
return s
},
provider: cloudprovider.Azure,
},
"gcp, azure not nil": {
stateFile: func() *State {
s := defaultState()
return s
},
provider: cloudprovider.GCP,
wantErr: true,
errAssertions: func(a *assert.Assertions, err error) {
a.Contains(err.Error(), "must be equal to <nil>")
a.Contains(err.Error(), "must be empty")
},
},
"azure, gcp not nil": {
stateFile: func() *State {
s := defaultState()
return s
},
provider: cloudprovider.Azure,
wantErr: true,
errAssertions: func(a *assert.Assertions, err error) {
a.Contains(err.Error(), "must be equal to <nil>")
a.Contains(err.Error(), "must be empty")
},
},
"cluster values invalid": {
stateFile: func() *State {
s := defaultState()
s.ClusterValues.ClusterID = ""
return s
},
wantErr: true,
errAssertions: func(a *assert.Assertions, err error) {
a.Contains(err.Error(), "validating State.clusterValues.clusterID: must not be empty")
},
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
err := tc.stateFile().Validate(PostInit, tc.provider)
if tc.wantErr {
require.Error(t, err)
if tc.errAssertions != nil {
tc.errAssertions(assert.New(t), err)
}
} else {
require.NoError(t, err)
}
})
}
}

View File

@ -15,6 +15,7 @@ go_library(
go_test(
name = "validation_test",
srcs = [
"constraints_test.go",
"errors_test.go",
"validation_test.go",
],

View File

@ -8,6 +8,7 @@ package validation
import (
"fmt"
"net"
"reflect"
"regexp"
)
@ -15,8 +16,10 @@ import (
// Constraint is a constraint on a document or a field of a document.
type Constraint struct {
// Satisfied returns no error if the constraint is satisfied.
// Otherwise, it returns the reason why the constraint is not satisfied.
Satisfied func() error
// Otherwise, it returns the reason why the constraint is not satisfied,
// possibly including its child errors, i.e., errors returned by constraints
// that are embedded in this constraint.
Satisfied func() *TreeError
}
/*
@ -36,7 +39,7 @@ Example for a pointer field:
Due to Go's addressability limititations regarding maps, if a map field is
to be validated, WithMapFieldTrace must be used instead of WithFieldTrace.
*/
func (c *Constraint) WithFieldTrace(doc any, field any) Constraint {
func (c *Constraint) WithFieldTrace(doc any, field any) *Constraint {
// we only want to dereference the needle once to dereference the pointer
// used to pass it to the function without losing reference to it, as the
// needle could be an arbitrarily long chain of pointers. The same
@ -69,7 +72,7 @@ Example:
For non-map fields, WithFieldTrace should be used instead of WithMapFieldTrace.
*/
func (c *Constraint) WithMapFieldTrace(doc any, field any, mapKey string) Constraint {
func (c *Constraint) WithMapFieldTrace(doc any, field any, mapKey string) *Constraint {
// we only want to dereference the needle once to dereference the pointer
// used to pass it to the function without losing reference to it, as the
// needle could be an arbitrarily long chain of pointers. The same
@ -91,11 +94,11 @@ func (c *Constraint) WithMapFieldTrace(doc any, field any, mapKey string) Constr
}
// withTrace wraps the constraint's error message with a well-formatted trace.
func (c *Constraint) withTrace(docRef, fieldRef referenceableValue) Constraint {
return Constraint{
Satisfied: func() error {
func (c *Constraint) withTrace(docRef, fieldRef referenceableValue) *Constraint {
return &Constraint{
Satisfied: func() *TreeError {
if err := c.Satisfied(); err != nil {
return newError(docRef, fieldRef, err)
return newTraceError(docRef, fieldRef, err)
}
return nil
},
@ -105,49 +108,207 @@ func (c *Constraint) withTrace(docRef, fieldRef referenceableValue) Constraint {
// MatchRegex is a constraint that if s matches regex.
func MatchRegex(s string, regex string) *Constraint {
return &Constraint{
Satisfied: func() error {
Satisfied: func() *TreeError {
if !regexp.MustCompile(regex).MatchString(s) {
return fmt.Errorf("%s must match the pattern %s", s, regex)
return NewErrorTree(fmt.Errorf("%s must match the pattern %s", s, regex))
}
return nil
},
}
}
// Equal is a constraint that if s is equal to t.
// Equal is a constraint that checks if s is equal to t.
func Equal[T comparable](s T, t T) *Constraint {
return &Constraint{
Satisfied: func() error {
Satisfied: func() *TreeError {
if s != t {
return fmt.Errorf("%v must be equal to %v", s, t)
return NewErrorTree(fmt.Errorf("%v must be equal to %v", s, t))
}
return nil
},
}
}
// NotEmpty is a constraint that if s is not empty.
func NotEmpty[T comparable](s T) *Constraint {
// NotEqual is a constraint that checks if s is not equal to t.
func NotEqual[T comparable](s T, t T) *Constraint {
return &Constraint{
Satisfied: func() error {
var zero T
if s == zero {
return fmt.Errorf("%v must not be empty", s)
Satisfied: func() *TreeError {
if Equal(s, t).Satisfied() == nil {
return NewErrorTree(fmt.Errorf("%v must not be equal to %v", s, t))
}
return nil
},
}
}
// Empty is a constraint that if s is empty.
// Empty is a constraint that checks if s is empty.
func Empty[T comparable](s T) *Constraint {
return &Constraint{
Satisfied: func() error {
Satisfied: func() *TreeError {
var zero T
if s != zero {
return fmt.Errorf("%v must be empty", s)
return NewErrorTree(fmt.Errorf("%v must be empty", s))
}
return nil
},
}
}
// NotEmpty is a constraint that checks if s is not empty.
func NotEmpty[T comparable](s T) *Constraint {
return &Constraint{
Satisfied: func() *TreeError {
if Empty(s).Satisfied() == nil {
return NewErrorTree(fmt.Errorf("must not be empty"))
}
return nil
},
}
}
// OneOf is a constraint that s is in the set of values p.
func OneOf[T comparable](s T, p []T) *Constraint {
return &Constraint{
Satisfied: func() *TreeError {
for _, v := range p {
if s == v {
return nil
}
}
return NewErrorTree(fmt.Errorf("%v must be one of %v", s, p))
},
}
}
// IPAddress is a constraint that checks if s is a valid IP address.
func IPAddress(s string) *Constraint {
return &Constraint{
Satisfied: func() *TreeError {
if net.ParseIP(s) == nil {
return NewErrorTree(fmt.Errorf("%s must be a valid IP address", s))
}
return nil
},
}
}
// CIDR is a constraint that checks if s is a valid CIDR.
func CIDR(s string) *Constraint {
return &Constraint{
Satisfied: func() *TreeError {
if _, _, err := net.ParseCIDR(s); err != nil {
return NewErrorTree(fmt.Errorf("%s must be a valid CIDR", s))
}
return nil
},
}
}
// DNSName is a constraint that checks if s is a valid DNS name.
func DNSName(s string) *Constraint {
return &Constraint{
Satisfied: func() *TreeError {
if _, err := net.LookupHost(s); err != nil {
return NewErrorTree(fmt.Errorf("%s must be a valid DNS name", s))
}
return nil
},
}
}
// EmptySlice is a constraint that checks if s is an empty slice.
func EmptySlice[T comparable](s []T) *Constraint {
return &Constraint{
Satisfied: func() *TreeError {
if len(s) != 0 {
return NewErrorTree(fmt.Errorf("%v must be empty", s))
}
return nil
},
}
}
// NotEmptySlice is a constraint that checks if slice s is not empty.
func NotEmptySlice[T comparable](s []T) *Constraint {
return &Constraint{
Satisfied: func() *TreeError {
if EmptySlice(s).Satisfied() == nil {
return NewErrorTree(fmt.Errorf("must not be empty"))
}
return nil
},
}
}
// All is a constraint that checks if all elements of s satisfy the constraint c.
// The constraint should be parametric in regards to the index of the element in s,
// as well as the element itself.
func All[T comparable](s []T, c func(i int, v T) *Constraint) *Constraint {
return &Constraint{
Satisfied: func() *TreeError {
retErr := NewErrorTree(fmt.Errorf("all of the constraints must be satisfied: "))
for i, v := range s {
if err := c(i, v).Satisfied(); err != nil {
retErr.appendChild(err)
}
}
if len(retErr.children) == 0 {
return nil
}
return retErr
},
}
}
// And groups multiple constraints in an "and" relation and fails according to the given strategy.
func And(errStrat ErrStrategy, constraints ...*Constraint) *Constraint {
return &Constraint{
Satisfied: func() *TreeError {
retErr := NewErrorTree(fmt.Errorf("all of the constraints must be satisfied: "))
for _, constraint := range constraints {
if err := constraint.Satisfied(); err != nil {
if errStrat == FailFast {
return err
}
retErr.appendChild(err)
}
}
if len(retErr.children) == 0 {
return nil
}
return retErr
},
}
}
// Or groups multiple constraints in an "or" relation.
func Or(constraints ...*Constraint) *Constraint {
return &Constraint{
Satisfied: func() *TreeError {
retErr := NewErrorTree(fmt.Errorf("at least one of the constraints must be satisfied: "))
for _, constraint := range constraints {
err := constraint.Satisfied()
if err == nil {
return nil
}
retErr.appendChild(err)
}
if len(retErr.children) == 0 {
return nil
}
return retErr
},
}
}
// IfNotNil evaluates a constraint if and only if s is not nil.
func IfNotNil[T comparable](s *T, c func() *Constraint) *Constraint {
return &Constraint{
Satisfied: func() *TreeError {
if s == nil {
return nil
}
return c().Satisfied()
},
}
}

View File

@ -0,0 +1,290 @@
/*
Copyright (c) Edgeless Systems GmbH
SPDX-License-Identifier: AGPL-3.0-only
*/
package validation
import (
"testing"
"github.com/stretchr/testify/require"
)
func TestIPAddress(t *testing.T) {
testCases := map[string]struct {
ip string
wantErr bool
}{
"valid ipv4": {
ip: "127.0.0.1",
},
"valid ipv6": {
ip: "2001:db8::68",
},
"invalid": {
ip: "invalid",
wantErr: true,
},
"empty": {
wantErr: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
err := IPAddress(tc.ip).Satisfied()
if tc.wantErr {
require.Error(t, err)
} else {
require.Nil(t, err)
}
})
}
}
func TestCIDR(t *testing.T) {
testCases := map[string]struct {
cidr string
wantErr bool
}{
"valid ipv4": {
cidr: "192.0.2.0/24",
},
"valid ipv6": {
cidr: "2001:db8::/32",
},
"invalid": {
cidr: "invalid",
wantErr: true,
},
"empty": {
wantErr: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
err := CIDR(tc.cidr).Satisfied()
if tc.wantErr {
require.Error(t, err)
} else {
require.Nil(t, err)
}
})
}
}
func TestDNSName(t *testing.T) {
testCases := map[string]struct {
dnsName string
wantErr bool
}{
"valid": {
dnsName: "example.com",
},
"invalid": {
dnsName: "invalid",
wantErr: true,
},
"empty": {
wantErr: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
err := DNSName(tc.dnsName).Satisfied()
if tc.wantErr {
require.Error(t, err)
} else {
require.Nil(t, err)
}
})
}
}
func TestEmptySlice(t *testing.T) {
testCases := map[string]struct {
s []any
wantErr bool
}{
"valid": {
s: []any{},
},
"nil": {
s: nil,
},
"invalid": {
s: []any{1},
wantErr: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
err := EmptySlice(tc.s).Satisfied()
if tc.wantErr {
require.Error(t, err)
} else {
require.Nil(t, err)
}
})
}
}
func TestNotEmptySlice(t *testing.T) {
testCases := map[string]struct {
s []any
wantErr bool
}{
"valid": {
s: []any{1},
},
"invalid": {
s: []any{},
wantErr: true,
},
"nil": {
s: nil,
wantErr: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
err := NotEmptySlice(tc.s).Satisfied()
if tc.wantErr {
require.Error(t, err)
} else {
require.Nil(t, err)
}
})
}
}
func TestAll(t *testing.T) {
c := func(i int, s string) *Constraint {
return Equal(s, "abc")
}
testCases := map[string]struct {
s []string
wantErr bool
}{
"valid": {
s: []string{"abc", "abc", "abc"},
},
"nil": {
s: nil,
},
"empty": {
s: []string{},
},
"all are invalid": {
s: []string{"def", "lol"},
wantErr: true,
},
"one is invalid": {
s: []string{"abc", "def"},
wantErr: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
err := All(tc.s, c).Satisfied()
if tc.wantErr {
require.Error(t, err)
} else {
require.Nil(t, err)
}
})
}
}
func TestNotEqual(t *testing.T) {
testCases := map[string]struct {
a any
b any
wantErr bool
}{
"valid": {
a: "abc",
b: "def",
},
"invalid": {
a: "abc",
b: "abc",
wantErr: true,
},
"empty": {
wantErr: true,
},
"one empty": {
a: "abc",
b: "",
},
"one nil": {
a: "abc",
b: nil,
},
"both nil": {
a: nil,
b: nil,
wantErr: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
err := NotEqual(tc.a, tc.b).Satisfied()
if tc.wantErr {
require.Error(t, err)
} else {
require.Nil(t, err)
}
})
}
}
func TestIfNotNil(t *testing.T) {
testCases := map[string]struct {
a *int
c func() *Constraint
wantErr bool
}{
"valid": {
a: new(int),
c: func() *Constraint {
return &Constraint{
Satisfied: func() *TreeError {
return nil
},
}
},
},
"nil": {
a: nil,
c: func() *Constraint {
return &Constraint{
Satisfied: func() *TreeError {
t.Fatal("should not be called")
return nil
},
}
},
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
err := IfNotNil(tc.a, tc.c).Satisfied()
if tc.wantErr {
require.Error(t, err)
} else {
require.Nil(t, err)
}
})
}
}

View File

@ -13,42 +13,86 @@ import (
"strings"
)
// Error is returned when a document is not valid.
type Error struct {
Path string
Err error
// TreeError is returned when a document is not valid.
// It contains the path to the field that failed validation, the error
// that occurred, as well as a list of child errors, as one constraint
// can embed multiple other constraints, e.g. in an OR.
type TreeError struct {
path string
err error
children []*TreeError
}
// NewErrorTree creates a new error tree from the given error.
func NewErrorTree(err error) *TreeError {
return &TreeError{
err: err,
children: []*TreeError{},
}
}
/*
newError creates a new validation Error.
newTraceError creates a new validation error, traced to a field.
To find the path to the exported field that failed validation, it traverses "doc"
recursively until it finds a field in "doc" that matches the reference to "field".
*/
func newError(doc, field referenceableValue, errMsg error) *Error {
func newTraceError(doc, field referenceableValue, errMsg error) *TreeError {
// traverse the top level struct (i.e. the "haystack") until addr (i.e. the "needle") is found
path, err := traverse(doc, field, newPathBuilder(doc._type.Name()))
if err != nil {
return &Error{
Path: "unknown",
Err: fmt.Errorf("cannot find path to field: %w. original error: %w", err, errMsg),
return &TreeError{
path: "unknown",
err: fmt.Errorf("cannot find path to field: %w. original error: %w", err, errMsg),
}
}
return &Error{
Path: path,
Err: errMsg,
return &TreeError{
path: path,
err: errMsg,
children: []*TreeError{},
}
}
// Error implements the error interface.
func (e *Error) Error() string {
return fmt.Sprintf("validating %s: %s", e.Path, e.Err)
func (e *TreeError) Error() string {
return e.format(0)
}
// Unwrap implements the error interface.
func (e *Error) Unwrap() error {
return e.Err
func (e *TreeError) Unwrap() error {
return e.err
}
// format formats the error tree and all of its children.
func (e *TreeError) format(indent int) string {
var sb strings.Builder
if e.path != "" {
sb.WriteString(fmt.Sprintf(
"%svalidating %s: %s",
strings.Repeat(" ", indent),
e.path,
e.err,
))
} else {
sb.WriteString(fmt.Sprintf(
"%s%s",
strings.Repeat(" ", indent),
e.err,
))
}
for _, child := range e.children {
sb.WriteString(fmt.Sprintf(
"\n%s",
child.format(indent+1),
))
}
return sb.String()
}
// appendChild adds the given child error to the tree.
func (e *TreeError) appendChild(child *TreeError) {
e.children = append(e.children, child)
}
/*
@ -238,9 +282,13 @@ func newPathBuilder(topLevelDoc string) pathBuilder {
func (p pathBuilder) appendStructField(field reflect.StructField) pathBuilder {
switch {
case field.Tag.Get("json") != "":
p.buf = append(p.buf, fmt.Sprintf(".%s", field.Tag.Get("json")))
// cut off omitempty or other options
jsonTagName, _, _ := strings.Cut(field.Tag.Get("json"), ",")
p.buf = append(p.buf, fmt.Sprintf(".%s", jsonTagName))
case field.Tag.Get("yaml") != "":
p.buf = append(p.buf, fmt.Sprintf(".%s", field.Tag.Get("yaml")))
// cut off omitempty or other options
yamlTagName, _, _ := strings.Cut(field.Tag.Get("yaml"), ",")
p.buf = append(p.buf, fmt.Sprintf(".%s", yamlTagName))
default:
p.buf = append(p.buf, fmt.Sprintf(".%s", field.Name))
}

View File

@ -15,6 +15,37 @@ import (
"github.com/stretchr/testify/require"
)
func TestErrorFormatting(t *testing.T) {
err := &TreeError{
path: "path",
err: fmt.Errorf("error"),
children: []*TreeError{},
}
assert.Equal(t, "validating path: error", err.Error())
err.children = append(err.children, &TreeError{
path: "child",
err: fmt.Errorf("child error"),
children: []*TreeError{},
})
assert.Equal(t, "validating path: error\n validating child: child error", err.Error())
err.children = append(err.children, &TreeError{
path: "child2",
err: fmt.Errorf("child2 error"),
children: []*TreeError{
{
path: "child2child",
err: fmt.Errorf("child2child error"),
children: []*TreeError{},
},
},
})
assert.Equal(t, "validating path: error\n validating child: child error\n validating child2: child2 error\n validating child2child: child2child error", err.Error())
}
// Tests for primitive / shallow fields
func TestNewValidationErrorSingleField(t *testing.T) {
@ -24,7 +55,7 @@ func TestNewValidationErrorSingleField(t *testing.T) {
}
doc, field := references(t, st, &st.OtherField, "")
err := newError(doc, field, assert.AnError)
err := newTraceError(doc, field, assert.AnError)
require.Error(t, err)
require.Contains(t, err.Error(), fmt.Sprintf("validating errorTestDoc.otherField: %s", assert.AnError))
}
@ -37,7 +68,7 @@ func TestNewValidationErrorSingleFieldPtr(t *testing.T) {
}
doc, field := references(t, st, &st.PointerField, "")
err := newError(doc, field, assert.AnError)
err := newTraceError(doc, field, assert.AnError)
require.Error(t, err)
require.Contains(t, err.Error(), fmt.Sprintf("validating errorTestDoc.pointerField: %s", assert.AnError))
}
@ -51,7 +82,7 @@ func TestNewValidationErrorSingleFieldDoublePtr(t *testing.T) {
}
doc, field := references(t, st, &st.DoublePointerField, "")
err := newError(doc, field, assert.AnError)
err := newTraceError(doc, field, assert.AnError)
require.Error(t, err)
require.Contains(t, err.Error(), fmt.Sprintf("validating errorTestDoc.doublePointerField: %s", assert.AnError))
}
@ -66,7 +97,7 @@ func TestNewValidationErrorSingleFieldInexistent(t *testing.T) {
inexistentField := 123
doc, field := references(t, st, &inexistentField, "")
err := newError(doc, field, assert.AnError)
err := newTraceError(doc, field, assert.AnError)
require.Error(t, err)
require.Contains(t, err.Error(), "cannot find path to field: cannot traverse anymore")
}
@ -84,7 +115,7 @@ func TestNewValidationErrorNestedField(t *testing.T) {
}
doc, field := references(t, st, &st.NestedField.OtherField, "")
err := newError(doc, field, assert.AnError)
err := newTraceError(doc, field, assert.AnError)
t.Log(err)
require.Error(t, err)
require.Contains(t, err.Error(), fmt.Sprintf("validating errorTestDoc.nestedField.otherField: %s", assert.AnError))
@ -102,7 +133,7 @@ func TestNewValidationErrorPointerInNestedField(t *testing.T) {
}
doc, field := references(t, st, &st.NestedField.PointerField, "")
err := newError(doc, field, assert.AnError)
err := newTraceError(doc, field, assert.AnError)
t.Log(err)
require.Error(t, err)
require.Contains(t, err.Error(), fmt.Sprintf("validating errorTestDoc.nestedField.pointerField: %s", assert.AnError))
@ -123,7 +154,7 @@ func TestNewValidationErrorNestedFieldPtr(t *testing.T) {
}
doc, field := references(t, st, &st.NestedPointerField.OtherField, "")
err := newError(doc, field, assert.AnError)
err := newTraceError(doc, field, assert.AnError)
t.Log(err)
require.Error(t, err)
require.Contains(t, err.Error(), fmt.Sprintf("validating errorTestDoc.nestedPointerField.otherField: %s", assert.AnError))
@ -144,7 +175,7 @@ func TestNewValidationErrorNestedNestedField(t *testing.T) {
}
doc, field := references(t, st, &st.NestedField.NestedField.OtherField, "")
err := newError(doc, field, assert.AnError)
err := newTraceError(doc, field, assert.AnError)
t.Log(err)
require.Error(t, err)
require.Contains(t, err.Error(), fmt.Sprintf("validating errorTestDoc.nestedField.nestedField.otherField: %s", assert.AnError))
@ -165,7 +196,7 @@ func TestNewValidationErrorNestedNestedFieldPtr(t *testing.T) {
}
doc, field := references(t, st, &st.NestedField.NestedPointerField.OtherField, "")
err := newError(doc, field, assert.AnError)
err := newTraceError(doc, field, assert.AnError)
t.Log(err)
require.Error(t, err)
require.Contains(t, err.Error(), fmt.Sprintf("validating errorTestDoc.nestedField.nestedPointerField.otherField: %s", assert.AnError))
@ -186,7 +217,7 @@ func TestNewValidationErrorNestedPtrNestedFieldPtr(t *testing.T) {
}
doc, field := references(t, st, &st.NestedPointerField.NestedPointerField.OtherField, "")
err := newError(doc, field, assert.AnError)
err := newTraceError(doc, field, assert.AnError)
t.Log(err)
require.Error(t, err)
require.Contains(t, err.Error(), fmt.Sprintf("validating errorTestDoc.nestedPointerField.nestedPointerField.otherField: %s", assert.AnError))
@ -200,7 +231,7 @@ func TestNewValidationErrorPrimitiveSlice(t *testing.T) {
}
doc, field := references(t, st, &st.PrimitiveSlice[1], "")
err := newError(doc, field, assert.AnError)
err := newTraceError(doc, field, assert.AnError)
t.Log(err)
require.Error(t, err)
require.Contains(t, err.Error(), fmt.Sprintf("validating sliceErrorTestDoc.primitiveSlice[1]: %s", assert.AnError))
@ -212,7 +243,7 @@ func TestNewValidationErrorPrimitiveArray(t *testing.T) {
}
doc, field := references(t, st, &st.PrimitiveArray[1], "")
err := newError(doc, field, assert.AnError)
err := newTraceError(doc, field, assert.AnError)
t.Log(err)
require.Error(t, err)
require.Contains(t, err.Error(), fmt.Sprintf("validating sliceErrorTestDoc.primitiveArray[1]: %s", assert.AnError))
@ -233,7 +264,7 @@ func TestNewValidationErrorStructSlice(t *testing.T) {
}
doc, field := references(t, st, &st.StructSlice[1].OtherField, "")
err := newError(doc, field, assert.AnError)
err := newTraceError(doc, field, assert.AnError)
t.Log(err)
require.Error(t, err)
require.Contains(t, err.Error(), fmt.Sprintf("validating sliceErrorTestDoc.structSlice[1].otherField: %s", assert.AnError))
@ -254,7 +285,7 @@ func TestNewValidationErrorStructArray(t *testing.T) {
}
doc, field := references(t, st, &st.StructArray[1].OtherField, "")
err := newError(doc, field, assert.AnError)
err := newTraceError(doc, field, assert.AnError)
t.Log(err)
require.Error(t, err)
require.Contains(t, err.Error(), fmt.Sprintf("validating sliceErrorTestDoc.structArray[1].otherField: %s", assert.AnError))
@ -275,7 +306,7 @@ func TestNewValidationErrorStructPointerSlice(t *testing.T) {
}
doc, field := references(t, st, &st.StructPointerSlice[1].OtherField, "")
err := newError(doc, field, assert.AnError)
err := newTraceError(doc, field, assert.AnError)
t.Log(err)
require.Error(t, err)
require.Contains(t, err.Error(), fmt.Sprintf("validating sliceErrorTestDoc.structPointerSlice[1].otherField: %s", assert.AnError))
@ -296,7 +327,7 @@ func TestNewValidationErrorStructPointerArray(t *testing.T) {
}
doc, field := references(t, st, &st.StructPointerArray[1].OtherField, "")
err := newError(doc, field, assert.AnError)
err := newTraceError(doc, field, assert.AnError)
t.Log(err)
require.Error(t, err)
require.Contains(t, err.Error(), fmt.Sprintf("validating sliceErrorTestDoc.structPointerArray[1].otherField: %s", assert.AnError))
@ -311,7 +342,7 @@ func TestNewValidationErrorPrimitiveSliceSlice(t *testing.T) {
}
doc, field := references(t, st, &st.PrimitiveSliceSlice[1][1], "")
err := newError(doc, field, assert.AnError)
err := newTraceError(doc, field, assert.AnError)
t.Log(err)
require.Error(t, err)
require.Contains(t, err.Error(), fmt.Sprintf("validating sliceErrorTestDoc.primitiveSliceSlice[1][1]: %s", assert.AnError))
@ -328,7 +359,7 @@ func TestNewValidationErrorPrimitiveMap(t *testing.T) {
}
doc, field := references(t, st, &st.PrimitiveMap, "ghi")
err := newError(doc, field, assert.AnError)
err := newTraceError(doc, field, assert.AnError)
t.Log(err)
require.Error(t, err)
require.Contains(t, err.Error(), fmt.Sprintf("validating mapErrorTestDoc.primitiveMap[\"ghi\"]: %s", assert.AnError))
@ -349,7 +380,7 @@ func TestNewValidationErrorStructPointerMap(t *testing.T) {
}
doc, field := references(t, st, &st.StructPointerMap["ghi"].OtherField, "")
err := newError(doc, field, assert.AnError)
err := newTraceError(doc, field, assert.AnError)
t.Log(err)
require.Error(t, err)
require.Contains(t, err.Error(), fmt.Sprintf("validating mapErrorTestDoc.structPointerMap[\"ghi\"].otherField: %s", assert.AnError))
@ -368,7 +399,7 @@ func TestNewValidationErrorNestedPrimitiveMap(t *testing.T) {
}
doc, field := references(t, st, st.NestedPointerMap["jkl"], "mno")
err := newError(doc, field, assert.AnError)
err := newTraceError(doc, field, assert.AnError)
t.Log(err)
require.Error(t, err)
require.Contains(t, err.Error(), fmt.Sprintf("validating mapErrorTestDoc.nestedPointerMap[\"jkl\"][\"mno\"]: %s", assert.AnError))
@ -383,7 +414,7 @@ func TestNewValidationErrorTopLevelIsNeedle(t *testing.T) {
}
doc, field := references(t, st, st, "")
err := newError(doc, field, assert.AnError)
err := newTraceError(doc, field, assert.AnError)
require.Error(t, err)
require.Contains(t, err.Error(), fmt.Sprintf("validating errorTestDoc: %s", assert.AnError))
}
@ -396,7 +427,7 @@ func TestNewValidationErrorUntaggedField(t *testing.T) {
}
doc, field := references(t, st, &st.NoTagField, "")
err := newError(doc, field, assert.AnError)
err := newTraceError(doc, field, assert.AnError)
require.Error(t, err)
require.Contains(t, err.Error(), fmt.Sprintf("validating errorTestDoc.NoTagField: %s", assert.AnError))
}
@ -410,7 +441,7 @@ func TestNewValidationErrorOnlyYamlTaggedField(t *testing.T) {
}
doc, field := references(t, st, &st.OnlyYamlKey, "")
err := newError(doc, field, assert.AnError)
err := newTraceError(doc, field, assert.AnError)
require.Error(t, err)
require.Contains(t, err.Error(), fmt.Sprintf("validating errorTestDoc.onlyYamlKey: %s", assert.AnError))
}

View File

@ -11,7 +11,19 @@ It validates documents that specify a set of constraints on their content.
*/
package validation
import "errors"
import (
"errors"
)
// ErrStrategy is the strategy to use when encountering an error during validation.
type ErrStrategy int
const (
// EvaluateAll continues evaluating all constraints even if one is not satisfied.
EvaluateAll ErrStrategy = iota
// FailFast stops validation on the first error.
FailFast
)
// NewValidator creates a new Validator.
func NewValidator() *Validator {
@ -24,21 +36,31 @@ type Validator struct{}
// Validatable is implemented by documents that can be validated.
// It returns a list of constraints that must be satisfied for the document to be valid.
type Validatable interface {
Constraints() []Constraint
Constraints() []*Constraint
}
// ValidateOptions are the options to use when validating a document.
type ValidateOptions struct {
// FailFast stops validation on the first error.
FailFast bool
// ErrStrategy is the strategy to use when encountering an error during validation.
ErrStrategy ErrStrategy
// OverrideConstraints overrides the constraints to use for validation.
// If nil, the constraints returned by the document are used.
OverrideConstraints func() []*Constraint
}
// Validate validates a document using the given options.
func (v *Validator) Validate(doc Validatable, opts ValidateOptions) error {
var constraints func() []*Constraint
if opts.OverrideConstraints != nil {
constraints = opts.OverrideConstraints
} else {
constraints = doc.Constraints
}
var retErr error
for _, c := range doc.Constraints() {
for _, c := range constraints() {
if err := c.Satisfied(); err != nil {
if opts.FailFast {
if opts.ErrStrategy == FailFast {
return err
}
retErr = errors.Join(retErr, err)

View File

@ -14,34 +14,39 @@ import (
"github.com/stretchr/testify/require"
)
var validDoc = func() *exampleDoc {
return &exampleDoc{
StrField: "abc",
NumField: 42,
MapField: &map[string]string{
"empty": "",
},
NotEmptyField: "certainly not.",
MatchRegexField: "abc",
OneOfField: "one",
OrLeftField: "left",
OrRightField: "right",
AndLeftField: "left",
AndRightField: "right",
}
}
func TestValidate(t *testing.T) {
testCases := map[string]struct {
doc Validatable
doc func() *exampleDoc
opts ValidateOptions
wantErr bool
errAssertion func(*assert.Assertions, error) bool
}{
"valid": {
doc: &exampleDoc{
StrField: "abc",
NumField: 42,
MapField: &map[string]string{
"empty": "",
},
NotEmptyField: "certainly not.",
MatchRegexField: "abc",
},
doc: validDoc,
opts: ValidateOptions{},
},
"strField is not abc": {
doc: &exampleDoc{
StrField: "def",
NumField: 42,
MapField: &map[string]string{
"empty": "",
},
NotEmptyField: "certainly not.",
MatchRegexField: "abc",
doc: func() *exampleDoc {
doc := validDoc()
doc.StrField = "def"
return doc
},
wantErr: true,
errAssertion: func(assert *assert.Assertions, err error) bool {
@ -50,14 +55,10 @@ func TestValidate(t *testing.T) {
opts: ValidateOptions{},
},
"numField is not 42": {
doc: &exampleDoc{
StrField: "abc",
NumField: 43,
MapField: &map[string]string{
"empty": "",
},
NotEmptyField: "certainly not.",
MatchRegexField: "abc",
doc: func() *exampleDoc {
doc := validDoc()
doc.NumField = 43
return doc
},
wantErr: true,
errAssertion: func(assert *assert.Assertions, err error) bool {
@ -65,14 +66,11 @@ func TestValidate(t *testing.T) {
},
},
"multiple errors": {
doc: &exampleDoc{
StrField: "def",
NumField: 43,
MapField: &map[string]string{
"empty": "",
},
NotEmptyField: "certainly not.",
MatchRegexField: "abc",
doc: func() *exampleDoc {
doc := validDoc()
doc.StrField = "def"
doc.NumField = 43
return doc
},
wantErr: true,
errAssertion: func(assert *assert.Assertions, err error) bool {
@ -82,75 +80,108 @@ func TestValidate(t *testing.T) {
opts: ValidateOptions{},
},
"multiple errors, fail fast": {
doc: &exampleDoc{
StrField: "def",
NumField: 43,
MapField: &map[string]string{
"empty": "",
},
NotEmptyField: "certainly not.",
MatchRegexField: "abc",
doc: func() *exampleDoc {
doc := validDoc()
doc.StrField = "def"
doc.NumField = 43
return doc
},
wantErr: true,
errAssertion: func(assert *assert.Assertions, err error) bool {
return assert.Contains(err.Error(), "validating exampleDoc.strField: def must be abc")
},
opts: ValidateOptions{
FailFast: true,
ErrStrategy: FailFast,
},
},
"map field is not empty": {
doc: &exampleDoc{
StrField: "abc",
NumField: 42,
MapField: &map[string]string{
doc: func() *exampleDoc {
doc := validDoc()
doc.MapField = &map[string]string{
"empty": "haha!",
},
NotEmptyField: "certainly not.",
MatchRegexField: "abc",
}
return doc
},
wantErr: true,
errAssertion: func(assert *assert.Assertions, err error) bool {
return assert.Contains(err.Error(), "validating exampleDoc.mapField[\"empty\"]: haha! must be empty")
},
opts: ValidateOptions{
FailFast: true,
ErrStrategy: FailFast,
},
},
"empty field is not empty": {
doc: &exampleDoc{
StrField: "abc",
NumField: 42,
MapField: &map[string]string{
"empty": "",
},
NotEmptyField: "",
MatchRegexField: "abc",
"not empty field is empty": {
doc: func() *exampleDoc {
doc := validDoc()
doc.NotEmptyField = ""
return doc
},
wantErr: true,
errAssertion: func(assert *assert.Assertions, err error) bool {
return assert.Contains(err.Error(), "validating exampleDoc.notEmptyField: must not be empty")
return assert.Contains(err.Error(), "validating exampleDoc.notEmptyField: must not be empty")
},
opts: ValidateOptions{
FailFast: true,
ErrStrategy: FailFast,
},
},
"regex doesnt match": {
doc: &exampleDoc{
StrField: "abc",
NumField: 42,
MapField: &map[string]string{
"empty": "",
},
NotEmptyField: "certainly not!",
MatchRegexField: "dontmatch",
doc: func() *exampleDoc {
doc := validDoc()
doc.MatchRegexField = "dontmatch"
return doc
},
wantErr: true,
errAssertion: func(assert *assert.Assertions, err error) bool {
return assert.Contains(err.Error(), "validating exampleDoc.matchRegexField: dontmatch must match the pattern ^a.c$")
},
opts: ValidateOptions{
FailFast: true,
ErrStrategy: FailFast,
},
},
"field is not in 'oneof' values": {
doc: func() *exampleDoc {
doc := validDoc()
doc.OneOfField = "not in oneof"
return doc
},
wantErr: true,
errAssertion: func(assert *assert.Assertions, err error) bool {
return assert.Contains(err.Error(), "validating exampleDoc.oneOfField: not in oneof must be one of [one two three]")
},
opts: ValidateOptions{
ErrStrategy: FailFast,
},
},
"'or' violated": {
doc: func() *exampleDoc {
doc := validDoc()
doc.OrLeftField = "not left"
doc.OrRightField = "not right"
return doc
},
wantErr: true,
errAssertion: func(assert *assert.Assertions, err error) bool {
return assert.Contains(err.Error(), "at least one of the constraints must be satisfied:") &&
assert.Contains(err.Error(), "validating exampleDoc.orLeftField: not left must be equal to left") &&
assert.Contains(err.Error(), "validating exampleDoc.orRightField: not right must be equal to right")
},
opts: ValidateOptions{
ErrStrategy: FailFast,
},
},
"'and' violated": {
doc: func() *exampleDoc {
doc := validDoc()
doc.AndRightField = "not right"
return doc
},
wantErr: true,
errAssertion: func(assert *assert.Assertions, err error) bool {
return assert.Contains(err.Error(), "all of the constraints must be satisfied:") &&
assert.Contains(err.Error(), "validating exampleDoc.andRightField: not right must be equal to right")
},
opts: ValidateOptions{
ErrStrategy: FailFast,
},
},
}
@ -160,7 +191,7 @@ func TestValidate(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
err := NewValidator().Validate(tc.doc, tc.opts)
err := NewValidator().Validate(tc.doc(), tc.opts)
if tc.wantErr {
require.Error(err)
if !tc.errAssertion(assert, err) {
@ -179,13 +210,18 @@ type exampleDoc struct {
MapField *map[string]string `json:"mapField"`
NotEmptyField string `json:"notEmptyField"`
MatchRegexField string `json:"matchRegexField"`
OneOfField string `json:"oneOfField"`
OrLeftField string `json:"orLeftField"`
OrRightField string `json:"orRightField"`
AndLeftField string `json:"andLeftField"`
AndRightField string `json:"andRightField"`
}
// Constraints implements the Validatable interface.
func (d *exampleDoc) Constraints() []Constraint {
func (d *exampleDoc) Constraints() []*Constraint {
mapField := *(d.MapField)
return []Constraint{
return []*Constraint{
d.strFieldNeedsToBeAbc().
WithFieldTrace(d, &d.StrField),
Equal(d.NumField, 42).
@ -196,17 +232,95 @@ func (d *exampleDoc) Constraints() []Constraint {
WithFieldTrace(d, &d.NotEmptyField),
MatchRegex(d.MatchRegexField, "^a.c$").
WithFieldTrace(d, &d.MatchRegexField),
OneOf(d.OneOfField, []string{"one", "two", "three"}).
WithFieldTrace(d, &d.OneOfField),
Or(
Equal(d.OrLeftField, "left").
WithFieldTrace(d, &d.OrLeftField),
Equal(d.OrRightField, "right").
WithFieldTrace(d, &d.OrRightField),
),
And(
EvaluateAll,
Equal(d.AndLeftField, "left").
WithFieldTrace(d, &d.AndLeftField),
Equal(d.AndRightField, "right").
WithFieldTrace(d, &d.AndRightField),
),
}
}
// StrFieldNeedsToBeAbc is an example for a custom constraint.
func (d *exampleDoc) strFieldNeedsToBeAbc() *Constraint {
return &Constraint{
Satisfied: func() error {
Satisfied: func() *TreeError {
if d.StrField != "abc" {
return fmt.Errorf("%s must be abc", d.StrField)
return NewErrorTree(
fmt.Errorf("%s must be abc", d.StrField),
)
}
return nil
},
}
}
func TestOverrideConstraints(t *testing.T) {
overrideConstraints := func(t *testing.T, wantCalled bool) func() []*Constraint {
return func() []*Constraint {
if !wantCalled {
t.Fatal("overrideConstraints should not be called")
}
return []*Constraint{}
}
}
testCases := map[string]struct {
doc exampleDocToOverride
overrideFunc func() []*Constraint
wantOverrideCalled bool
wantErr bool
}{
"override constraints": {
doc: exampleDocToOverride{},
overrideFunc: overrideConstraints(t, true),
wantOverrideCalled: true,
},
"do not override constraints": {
doc: exampleDocToOverride{
calledDocConstraints: true,
},
overrideFunc: nil,
wantOverrideCalled: false,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
validator := NewValidator()
err := validator.Validate(&tc.doc, ValidateOptions{
OverrideConstraints: tc.overrideFunc,
})
if tc.wantErr {
require.Error(err)
} else {
require.NoError(err)
if tc.wantOverrideCalled {
assert.Equal(tc.doc.calledDocConstraints, false)
}
}
})
}
}
type exampleDocToOverride struct {
calledDocConstraints bool
}
func (d *exampleDocToOverride) Constraints() []*Constraint {
d.calledDocConstraints = true
return []*Constraint{}
}