diff --git a/internal/validation/BUILD.bazel b/internal/validation/BUILD.bazel new file mode 100644 index 000000000..70b508991 --- /dev/null +++ b/internal/validation/BUILD.bazel @@ -0,0 +1,26 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_library") +load("//bazel/go:go_test.bzl", "go_test") + +go_library( + name = "validation", + srcs = [ + "constraints.go", + "errors.go", + "validation.go", + ], + importpath = "github.com/edgelesssys/constellation/v2/internal/validation", + visibility = ["//:__subpackages__"], +) + +go_test( + name = "validation_test", + srcs = [ + "errors_test.go", + "validation_test.go", + ], + embed = [":validation"], + deps = [ + "@com_github_stretchr_testify//assert", + "@com_github_stretchr_testify//require", + ], +) diff --git a/internal/validation/constraints.go b/internal/validation/constraints.go new file mode 100644 index 000000000..01c299135 --- /dev/null +++ b/internal/validation/constraints.go @@ -0,0 +1,153 @@ +/* +Copyright (c) Edgeless Systems GmbH + +SPDX-License-Identifier: AGPL-3.0-only +*/ + +package validation + +import ( + "fmt" + "reflect" + "regexp" +) + +// Constraint is a constraint on a document or a field of a document. +type Constraint struct { + // Satisfied returns no error if the constraint is satisfied. + // Otherwise, it returns the reason why the constraint is not satisfied. + Satisfied func() error +} + +/* +WithFieldTrace adds a well-formatted trace to the field to the error message +shown when the constraint is not satisfied. Both "doc" and "field" must be pointers: + - "doc" must be a pointer to the top level document + - "field" must be a pointer to the field to be validated + +Example for a non-pointer field: + + Equal(d.IntField, 42).WithFieldTrace(d, &d.IntField) + +Example for a pointer field: + + NotEmpty(d.StrPtrField).WithFieldTrace(d, d.StrPtrField) + +Due to Go's addressability limititations regarding maps, if a map field is +to be validated, WithMapFieldTrace must be used instead of WithFieldTrace. +*/ +func (c *Constraint) WithFieldTrace(doc any, field any) Constraint { + // we only want to dereference the needle once to dereference the pointer + // used to pass it to the function without losing reference to it, as the + // needle could be an arbitrarily long chain of pointers. The same + // applies to the haystack. + derefedField := pointerDeref(reflect.ValueOf(field)) + fieldRef := referenceableValue{ + value: derefedField, + addr: derefedField.UnsafeAddr(), + _type: derefedField.Type(), + } + derefedDoc := pointerDeref(reflect.ValueOf(doc)) + docRef := referenceableValue{ + value: derefedDoc, + addr: derefedDoc.UnsafeAddr(), + _type: derefedDoc.Type(), + } + return c.withTrace(docRef, fieldRef) +} + +/* +WithMapFieldTrace adds a well-formatted trace to the map field to the error message +shown when the constraint is not satisfied. Both "doc" and "field" must be pointers: + - "doc" must be a pointer to the top level document + - "field" must be a pointer to the map containing the field to be validated + - "mapKey" must be the key of the field to be validated in the map pointed to by "field" + +Example: + + Equal(d.IntField, 42).WithMapFieldTrace(d, &d.MapField, mapKey) + +For non-map fields, WithFieldTrace should be used instead of WithMapFieldTrace. +*/ +func (c *Constraint) WithMapFieldTrace(doc any, field any, mapKey string) Constraint { + // we only want to dereference the needle once to dereference the pointer + // used to pass it to the function without losing reference to it, as the + // needle could be an arbitrarily long chain of pointers. The same + // applies to the haystack. + derefedField := pointerDeref(reflect.ValueOf(field)) + fieldRef := referenceableValue{ + value: derefedField, + addr: derefedField.UnsafeAddr(), + _type: derefedField.Type(), + mapKey: mapKey, + } + derefedDoc := pointerDeref(reflect.ValueOf(doc)) + docRef := referenceableValue{ + value: derefedDoc, + addr: derefedDoc.UnsafeAddr(), + _type: derefedDoc.Type(), + } + return c.withTrace(docRef, fieldRef) +} + +// withTrace wraps the constraint's error message with a well-formatted trace. +func (c *Constraint) withTrace(docRef, fieldRef referenceableValue) Constraint { + return Constraint{ + Satisfied: func() error { + if err := c.Satisfied(); err != nil { + return newError(docRef, fieldRef, err) + } + return nil + }, + } +} + +// MatchRegex is a constraint that if s matches regex. +func MatchRegex(s string, regex string) *Constraint { + return &Constraint{ + Satisfied: func() error { + if !regexp.MustCompile(regex).MatchString(s) { + return fmt.Errorf("%s must match the pattern %s", s, regex) + } + return nil + }, + } +} + +// Equal is a constraint that if s is equal to t. +func Equal[T comparable](s T, t T) *Constraint { + return &Constraint{ + Satisfied: func() error { + if s != t { + return fmt.Errorf("%v must be equal to %v", s, t) + } + return nil + }, + } +} + +// NotEmpty is a constraint that if s is not empty. +func NotEmpty[T comparable](s T) *Constraint { + return &Constraint{ + Satisfied: func() error { + var zero T + if s == zero { + return fmt.Errorf("%v must not be empty", s) + } + return nil + }, + } +} + +// Empty is a constraint that if s is empty. +func Empty[T comparable](s T) *Constraint { + return &Constraint{ + Satisfied: func() error { + var zero T + if s != zero { + return fmt.Errorf("%v must be empty", s) + } + return nil + }, + } +} diff --git a/internal/validation/errors.go b/internal/validation/errors.go new file mode 100644 index 000000000..05fc74a7f --- /dev/null +++ b/internal/validation/errors.go @@ -0,0 +1,269 @@ +/* +Copyright (c) Edgeless Systems GmbH + +SPDX-License-Identifier: AGPL-3.0-only +*/ + +package validation + +import ( + "errors" + "fmt" + "reflect" + "strings" +) + +// Error is returned when a document is not valid. +type Error struct { + Path string + Err error +} + +/* +newError creates a new validation Error. + +To find the path to the exported field that failed validation, it traverses "doc" +recursively until it finds a field in "doc" that matches the reference to "field". +*/ +func newError(doc, field referenceableValue, errMsg error) *Error { + // traverse the top level struct (i.e. the "haystack") until addr (i.e. the "needle") is found + path, err := traverse(doc, field, newPathBuilder(doc._type.Name())) + if err != nil { + return &Error{ + Path: "unknown", + Err: fmt.Errorf("cannot find path to field: %w. original error: %w", err, errMsg), + } + } + + return &Error{ + Path: path, + Err: errMsg, + } +} + +// Error implements the error interface. +func (e *Error) Error() string { + return fmt.Sprintf("validating %s: %s", e.Path, e.Err) +} + +// Unwrap implements the error interface. +func (e *Error) Unwrap() error { + return e.Err +} + +/* +traverse "haystack" recursively until it finds a field that matches +the reference saved in "needle", while building a pseudo-JSONPath to the field. + +If it traverses a level down, it appends the name of the struct tag +or another entity like array index or map field to path. + +When a field matches the reference to the given field, it returns the +path to the field. +*/ +func traverse(haystack referenceableValue, needle referenceableValue, path pathBuilder) (string, error) { + // recursion anchor: doc is the field we are looking for. + // Join the path and return. + if foundNeedle(haystack, needle) { + return path.string(), nil + } + + kind := haystack._type.Kind() + switch kind { + case reflect.Struct: + // Traverse all visible struct fields. + for _, field := range reflect.VisibleFields(haystack._type) { + // skip unexported fields + if !field.IsExported() { + continue + } + + fieldVal := recPointerDeref(haystack.value.FieldByName(field.Name)) + if isNilPtrOrInvalid(fieldVal) { + continue + } + + fieldAddr := haystack.addr + field.Offset + newHaystack := referenceableValue{ + value: fieldVal, + addr: fieldVal.UnsafeAddr(), + _type: fieldVal.Type(), + } + if canTraverse(fieldVal) { + // When a field is not the needle and cannot be traversed further, + // a errCannotTraverse is returned. Therefore, we only want to handle + // the case where the field is the needle. + if path, err := traverse(newHaystack, needle, path.appendStructField(field)); err == nil { + return path, nil + } + } + if foundNeedle(referenceableValue{addr: fieldAddr, _type: field.Type}, needle) { + return path.appendStructField(field).string(), nil + } + } + case reflect.Slice, reflect.Array: + // Traverse slice / Array elements + for i := 0; i < haystack.value.Len(); i++ { + // see struct case + itemVal := recPointerDeref(haystack.value.Index(i)) + if isNilPtrOrInvalid(itemVal) { + continue + } + newHaystack := referenceableValue{ + value: itemVal, + addr: itemVal.UnsafeAddr(), + _type: itemVal.Type(), + } + if canTraverse(itemVal) { + if path, err := traverse(newHaystack, needle, path.appendArrayIndex(i)); err == nil { + return path, nil + } + } + if foundNeedle(newHaystack, needle) { + return path.appendArrayIndex(i).string(), nil + } + } + case reflect.Map: + // Traverse map elements + iter := haystack.value.MapRange() + for iter.Next() { + // see struct case + mapKey := iter.Key().String() + mapVal := recPointerDeref(iter.Value()) + if isNilPtrOrInvalid(mapVal) { + continue + } + if canTraverse(mapVal) { + newHaystack := referenceableValue{ + value: mapVal, + addr: mapVal.UnsafeAddr(), + _type: mapVal.Type(), + mapKey: mapKey, + } + if path, err := traverse(newHaystack, needle, path.appendMapKey(mapKey)); err == nil { + return path, nil + } + } + // check if reference to map is the needle and the map key matches + if foundNeedle(referenceableValue{addr: haystack.addr, _type: haystack._type, mapKey: mapKey}, needle) { + return path.appendMapKey(mapKey).string(), nil + } + } + } + + // Primitive type, but not the value we are looking for. + return "", errCannotTraverse +} + +// referenceableValue is a type that can be passed as any (thus being copied) without losing the reference to the actual value. +type referenceableValue struct { + value reflect.Value + _type reflect.Type + mapKey string // special case for map values, which are not addressable + addr uintptr +} + +// errCannotTraverse is returned when a field cannot be traversed further. +var errCannotTraverse = errors.New("cannot traverse anymore") + +// recPointerDeref recursively dereferences pointers and unpacks interfaces until a non-pointer value is found. +func recPointerDeref(val reflect.Value) reflect.Value { + switch val.Kind() { + case reflect.Ptr, reflect.UnsafePointer, reflect.Interface: + return recPointerDeref(val.Elem()) + } + return val +} + +// pointerDeref dereferences pointers and unpacks interfaces. +// If the value is not a pointer, it is returned unchanged. +func pointerDeref(val reflect.Value) reflect.Value { + switch val.Kind() { + case reflect.Ptr, reflect.UnsafePointer, reflect.Interface: + return val.Elem() + } + return val +} + +/* +canTraverse whether a value can be further traversed. + +For pointer types, false is returned. +*/ +func canTraverse(v reflect.Value) bool { + switch v.Kind() { + case reflect.Struct, reflect.Slice, reflect.Array, reflect.Map: + return true + } + return false +} + +// isNilPtrOrInvalid returns true if a value is a nil pointer or if the value is of an invalid kind. +func isNilPtrOrInvalid(v reflect.Value) bool { + switch v.Kind() { + case reflect.Ptr, reflect.UnsafePointer, reflect.Interface, reflect.Slice, reflect.Map: + return v.IsNil() + case reflect.Invalid: + return true + } + return false +} + +/* +foundNeedle returns whether the given value is the needle. + +It does so by comparing the address and type of the value to the address and type of the needle. +The comparison of types is necessary because the first value of a struct has the same address as the struct itself. +*/ +func foundNeedle(haystack, needle referenceableValue) bool { + return haystack.addr == needle.addr && + haystack._type == needle._type && + haystack.mapKey == needle.mapKey +} + +// pathBuilder is a helper to build a field path. +type pathBuilder struct { + buf []string // slice can be copied by value when its non-zero, contrary to a strings.Builder +} + +// newPathBuilder creates a new pathBuilder from the identifier of a top level document. +func newPathBuilder(topLevelDoc string) pathBuilder { + return pathBuilder{ + buf: []string{topLevelDoc}, + } +} + +// appendStructField appends the JSON / YAML struct tag of a field to the path. +// If no struct tag is present, the field name is used. +func (p pathBuilder) appendStructField(field reflect.StructField) pathBuilder { + switch { + case field.Tag.Get("json") != "": + p.buf = append(p.buf, fmt.Sprintf(".%s", field.Tag.Get("json"))) + case field.Tag.Get("yaml") != "": + p.buf = append(p.buf, fmt.Sprintf(".%s", field.Tag.Get("yaml"))) + default: + p.buf = append(p.buf, fmt.Sprintf(".%s", field.Name)) + } + return p +} + +// appendArrayIndex appends the index of an array to the path. +func (p pathBuilder) appendArrayIndex(i int) pathBuilder { + p.buf = append(p.buf, fmt.Sprintf("[%d]", i)) + return p +} + +// appendMapKey appends the key of a map to the path. +func (p pathBuilder) appendMapKey(k string) pathBuilder { + p.buf = append(p.buf, fmt.Sprintf("[\"%s\"]", k)) + return p +} + +// string returns the path. +func (p pathBuilder) string() string { + // Remove struct tag prefix + return strings.TrimPrefix( + strings.Join(p.buf, ""), + ".", + ) +} diff --git a/internal/validation/errors_test.go b/internal/validation/errors_test.go new file mode 100644 index 000000000..6065dc77a --- /dev/null +++ b/internal/validation/errors_test.go @@ -0,0 +1,476 @@ +/* +Copyright (c) Edgeless Systems GmbH + +SPDX-License-Identifier: AGPL-3.0-only +*/ + +package validation + +import ( + "fmt" + "reflect" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// Tests for primitive / shallow fields + +func TestNewValidationErrorSingleField(t *testing.T) { + st := &errorTestDoc{ + ExportedField: "abc", + OtherField: 42, + } + + doc, field := references(t, st, &st.OtherField, "") + err := newError(doc, field, assert.AnError) + require.Error(t, err) + require.Contains(t, err.Error(), fmt.Sprintf("validating errorTestDoc.otherField: %s", assert.AnError)) +} + +func TestNewValidationErrorSingleFieldPtr(t *testing.T) { + st := &errorTestDoc{ + ExportedField: "abc", + OtherField: 42, + PointerField: new(int), + } + + doc, field := references(t, st, &st.PointerField, "") + err := newError(doc, field, assert.AnError) + require.Error(t, err) + require.Contains(t, err.Error(), fmt.Sprintf("validating errorTestDoc.pointerField: %s", assert.AnError)) +} + +func TestNewValidationErrorSingleFieldDoublePtr(t *testing.T) { + intp := new(int) + st := &errorTestDoc{ + ExportedField: "abc", + OtherField: 42, + DoublePointerField: &intp, + } + + doc, field := references(t, st, &st.DoublePointerField, "") + err := newError(doc, field, assert.AnError) + require.Error(t, err) + require.Contains(t, err.Error(), fmt.Sprintf("validating errorTestDoc.doublePointerField: %s", assert.AnError)) +} + +func TestNewValidationErrorSingleFieldInexistent(t *testing.T) { + st := &errorTestDoc{ + ExportedField: "abc", + OtherField: 42, + PointerField: new(int), + } + + inexistentField := 123 + + doc, field := references(t, st, &inexistentField, "") + err := newError(doc, field, assert.AnError) + require.Error(t, err) + require.Contains(t, err.Error(), "cannot find path to field: cannot traverse anymore") +} + +// Tests for nested structs + +func TestNewValidationErrorNestedField(t *testing.T) { + st := &errorTestDoc{ + ExportedField: "abc", + OtherField: 42, + NestedField: nestederrorTestDoc{ + ExportedField: "nested", + OtherField: 123, + }, + } + + doc, field := references(t, st, &st.NestedField.OtherField, "") + err := newError(doc, field, assert.AnError) + t.Log(err) + require.Error(t, err) + require.Contains(t, err.Error(), fmt.Sprintf("validating errorTestDoc.nestedField.otherField: %s", assert.AnError)) +} + +func TestNewValidationErrorPointerInNestedField(t *testing.T) { + st := &errorTestDoc{ + ExportedField: "abc", + OtherField: 42, + NestedField: nestederrorTestDoc{ + ExportedField: "nested", + OtherField: 123, + PointerField: new(int), + }, + } + + doc, field := references(t, st, &st.NestedField.PointerField, "") + err := newError(doc, field, assert.AnError) + t.Log(err) + require.Error(t, err) + require.Contains(t, err.Error(), fmt.Sprintf("validating errorTestDoc.nestedField.pointerField: %s", assert.AnError)) +} + +func TestNewValidationErrorNestedFieldPtr(t *testing.T) { + st := &errorTestDoc{ + ExportedField: "abc", + OtherField: 42, + NestedField: nestederrorTestDoc{ + ExportedField: "nested", + OtherField: 123, + }, + NestedPointerField: &nestederrorTestDoc{ + ExportedField: "nested", + OtherField: 123, + }, + } + + doc, field := references(t, st, &st.NestedPointerField.OtherField, "") + err := newError(doc, field, assert.AnError) + t.Log(err) + require.Error(t, err) + require.Contains(t, err.Error(), fmt.Sprintf("validating errorTestDoc.nestedPointerField.otherField: %s", assert.AnError)) +} + +func TestNewValidationErrorNestedNestedField(t *testing.T) { + st := &errorTestDoc{ + ExportedField: "abc", + OtherField: 42, + NestedField: nestederrorTestDoc{ + ExportedField: "nested", + OtherField: 123, + NestedField: nestedNestederrorTestDoc{ + ExportedField: "nested", + OtherField: 123, + }, + }, + } + + doc, field := references(t, st, &st.NestedField.NestedField.OtherField, "") + err := newError(doc, field, assert.AnError) + t.Log(err) + require.Error(t, err) + require.Contains(t, err.Error(), fmt.Sprintf("validating errorTestDoc.nestedField.nestedField.otherField: %s", assert.AnError)) +} + +func TestNewValidationErrorNestedNestedFieldPtr(t *testing.T) { + st := &errorTestDoc{ + ExportedField: "abc", + OtherField: 42, + NestedField: nestederrorTestDoc{ + ExportedField: "nested", + OtherField: 123, + NestedPointerField: &nestedNestederrorTestDoc{ + ExportedField: "nested", + OtherField: 123, + }, + }, + } + + doc, field := references(t, st, &st.NestedField.NestedPointerField.OtherField, "") + err := newError(doc, field, assert.AnError) + t.Log(err) + require.Error(t, err) + require.Contains(t, err.Error(), fmt.Sprintf("validating errorTestDoc.nestedField.nestedPointerField.otherField: %s", assert.AnError)) +} + +func TestNewValidationErrorNestedPtrNestedFieldPtr(t *testing.T) { + st := &errorTestDoc{ + ExportedField: "abc", + OtherField: 42, + NestedPointerField: &nestederrorTestDoc{ + ExportedField: "nested", + OtherField: 123, + NestedPointerField: &nestedNestederrorTestDoc{ + ExportedField: "nested", + OtherField: 123, + }, + }, + } + + doc, field := references(t, st, &st.NestedPointerField.NestedPointerField.OtherField, "") + err := newError(doc, field, assert.AnError) + t.Log(err) + require.Error(t, err) + require.Contains(t, err.Error(), fmt.Sprintf("validating errorTestDoc.nestedPointerField.nestedPointerField.otherField: %s", assert.AnError)) +} + +// Tests for slices / arrays + +func TestNewValidationErrorPrimitiveSlice(t *testing.T) { + st := &sliceErrorTestDoc{ + PrimitiveSlice: []string{"abc", "def"}, + } + + doc, field := references(t, st, &st.PrimitiveSlice[1], "") + err := newError(doc, field, assert.AnError) + t.Log(err) + require.Error(t, err) + require.Contains(t, err.Error(), fmt.Sprintf("validating sliceErrorTestDoc.primitiveSlice[1]: %s", assert.AnError)) +} + +func TestNewValidationErrorPrimitiveArray(t *testing.T) { + st := &sliceErrorTestDoc{ + PrimitiveArray: [3]int{1, 2, 3}, + } + + doc, field := references(t, st, &st.PrimitiveArray[1], "") + err := newError(doc, field, assert.AnError) + t.Log(err) + require.Error(t, err) + require.Contains(t, err.Error(), fmt.Sprintf("validating sliceErrorTestDoc.primitiveArray[1]: %s", assert.AnError)) +} + +func TestNewValidationErrorStructSlice(t *testing.T) { + st := &sliceErrorTestDoc{ + StructSlice: []errorTestDoc{ + { + ExportedField: "abc", + OtherField: 123, + }, + { + ExportedField: "def", + OtherField: 456, + }, + }, + } + + doc, field := references(t, st, &st.StructSlice[1].OtherField, "") + err := newError(doc, field, assert.AnError) + t.Log(err) + require.Error(t, err) + require.Contains(t, err.Error(), fmt.Sprintf("validating sliceErrorTestDoc.structSlice[1].otherField: %s", assert.AnError)) +} + +func TestNewValidationErrorStructArray(t *testing.T) { + st := &sliceErrorTestDoc{ + StructArray: [3]errorTestDoc{ + { + ExportedField: "abc", + OtherField: 123, + }, + { + ExportedField: "def", + OtherField: 456, + }, + }, + } + + doc, field := references(t, st, &st.StructArray[1].OtherField, "") + err := newError(doc, field, assert.AnError) + t.Log(err) + require.Error(t, err) + require.Contains(t, err.Error(), fmt.Sprintf("validating sliceErrorTestDoc.structArray[1].otherField: %s", assert.AnError)) +} + +func TestNewValidationErrorStructPointerSlice(t *testing.T) { + st := &sliceErrorTestDoc{ + StructPointerSlice: []*errorTestDoc{ + { + ExportedField: "abc", + OtherField: 123, + }, + { + ExportedField: "def", + OtherField: 456, + }, + }, + } + + doc, field := references(t, st, &st.StructPointerSlice[1].OtherField, "") + err := newError(doc, field, assert.AnError) + t.Log(err) + require.Error(t, err) + require.Contains(t, err.Error(), fmt.Sprintf("validating sliceErrorTestDoc.structPointerSlice[1].otherField: %s", assert.AnError)) +} + +func TestNewValidationErrorStructPointerArray(t *testing.T) { + st := &sliceErrorTestDoc{ + StructPointerArray: [3]*errorTestDoc{ + { + ExportedField: "abc", + OtherField: 123, + }, + { + ExportedField: "def", + OtherField: 456, + }, + }, + } + + doc, field := references(t, st, &st.StructPointerArray[1].OtherField, "") + err := newError(doc, field, assert.AnError) + t.Log(err) + require.Error(t, err) + require.Contains(t, err.Error(), fmt.Sprintf("validating sliceErrorTestDoc.structPointerArray[1].otherField: %s", assert.AnError)) +} + +func TestNewValidationErrorPrimitiveSliceSlice(t *testing.T) { + st := &sliceErrorTestDoc{ + PrimitiveSliceSlice: [][]string{ + {"abc", "def"}, + {"ghi", "jkl"}, + }, + } + + doc, field := references(t, st, &st.PrimitiveSliceSlice[1][1], "") + err := newError(doc, field, assert.AnError) + t.Log(err) + require.Error(t, err) + require.Contains(t, err.Error(), fmt.Sprintf("validating sliceErrorTestDoc.primitiveSliceSlice[1][1]: %s", assert.AnError)) +} + +// Tests for maps + +func TestNewValidationErrorPrimitiveMap(t *testing.T) { + st := &mapErrorTestDoc{ + PrimitiveMap: map[string]string{ + "abc": "def", + "ghi": "jkl", + }, + } + + doc, field := references(t, st, &st.PrimitiveMap, "ghi") + err := newError(doc, field, assert.AnError) + t.Log(err) + require.Error(t, err) + require.Contains(t, err.Error(), fmt.Sprintf("validating mapErrorTestDoc.primitiveMap[\"ghi\"]: %s", assert.AnError)) +} + +func TestNewValidationErrorStructPointerMap(t *testing.T) { + st := &mapErrorTestDoc{ + StructPointerMap: map[string]*errorTestDoc{ + "abc": { + ExportedField: "abc", + OtherField: 123, + }, + "ghi": { + ExportedField: "ghi", + OtherField: 456, + }, + }, + } + + doc, field := references(t, st, &st.StructPointerMap["ghi"].OtherField, "") + err := newError(doc, field, assert.AnError) + t.Log(err) + require.Error(t, err) + require.Contains(t, err.Error(), fmt.Sprintf("validating mapErrorTestDoc.structPointerMap[\"ghi\"].otherField: %s", assert.AnError)) +} + +func TestNewValidationErrorNestedPrimitiveMap(t *testing.T) { + st := &mapErrorTestDoc{ + NestedPointerMap: map[string]*map[string]string{ + "abc": { + "def": "ghi", + }, + "jkl": { + "mno": "pqr", + }, + }, + } + + doc, field := references(t, st, st.NestedPointerMap["jkl"], "mno") + err := newError(doc, field, assert.AnError) + t.Log(err) + require.Error(t, err) + require.Contains(t, err.Error(), fmt.Sprintf("validating mapErrorTestDoc.nestedPointerMap[\"jkl\"][\"mno\"]: %s", assert.AnError)) +} + +// Special cases + +func TestNewValidationErrorTopLevelIsNeedle(t *testing.T) { + st := &errorTestDoc{ + ExportedField: "abc", + OtherField: 42, + } + + doc, field := references(t, st, st, "") + err := newError(doc, field, assert.AnError) + require.Error(t, err) + require.Contains(t, err.Error(), fmt.Sprintf("validating errorTestDoc: %s", assert.AnError)) +} + +func TestNewValidationErrorUntaggedField(t *testing.T) { + st := &errorTestDoc{ + ExportedField: "abc", + OtherField: 42, + NoTagField: 123, + } + + doc, field := references(t, st, &st.NoTagField, "") + err := newError(doc, field, assert.AnError) + require.Error(t, err) + require.Contains(t, err.Error(), fmt.Sprintf("validating errorTestDoc.NoTagField: %s", assert.AnError)) +} + +func TestNewValidationErrorOnlyYamlTaggedField(t *testing.T) { + st := &errorTestDoc{ + ExportedField: "abc", + OtherField: 42, + NoTagField: 123, + OnlyYamlKey: "abc", + } + + doc, field := references(t, st, &st.OnlyYamlKey, "") + err := newError(doc, field, assert.AnError) + require.Error(t, err) + require.Contains(t, err.Error(), fmt.Sprintf("validating errorTestDoc.onlyYamlKey: %s", assert.AnError)) +} + +type errorTestDoc struct { + ExportedField string `json:"exportedField" yaml:"exportedField"` + OtherField int `json:"otherField" yaml:"otherField"` + PointerField *int `json:"pointerField" yaml:"pointerField"` + DoublePointerField **int `json:"doublePointerField" yaml:"doublePointerField"` + NestedField nestederrorTestDoc `json:"nestedField" yaml:"nestedField"` + NestedPointerField *nestederrorTestDoc `json:"nestedPointerField" yaml:"nestedPointerField"` + NoTagField int + OnlyYamlKey string `yaml:"onlyYamlKey"` +} + +type nestederrorTestDoc struct { + ExportedField string `json:"exportedField" yaml:"exportedField"` + OtherField int `json:"otherField" yaml:"otherField"` + PointerField *int `json:"pointerField" yaml:"pointerField"` + NestedField nestedNestederrorTestDoc `json:"nestedField" yaml:"nestedField"` + NestedPointerField *nestedNestederrorTestDoc `json:"nestedPointerField" yaml:"nestedPointerField"` +} + +type nestedNestederrorTestDoc struct { + ExportedField string `json:"exportedField" yaml:"exportedField"` + OtherField int `json:"otherField" yaml:"otherField"` + PointerField *int `json:"pointerField" yaml:"pointerField"` +} + +type sliceErrorTestDoc struct { + PrimitiveSlice []string `json:"primitiveSlice" yaml:"primitiveSlice"` + PrimitiveArray [3]int `json:"primitiveArray" yaml:"primitiveArray"` + StructSlice []errorTestDoc `json:"structSlice" yaml:"structSlice"` + StructArray [3]errorTestDoc `json:"structArray" yaml:"structArray"` + StructPointerSlice []*errorTestDoc `json:"structPointerSlice" yaml:"structPointerSlice"` + StructPointerArray [3]*errorTestDoc `json:"structPointerArray" yaml:"structPointerArray"` + PrimitiveSliceSlice [][]string `json:"primitiveSliceSlice" yaml:"primitiveSliceSlice"` +} + +type mapErrorTestDoc struct { + PrimitiveMap map[string]string `json:"primitiveMap" yaml:"primitiveMap"` + StructPointerMap map[string]*errorTestDoc `json:"structPointerMap" yaml:"structPointerMap"` + NestedPointerMap map[string]*map[string]string `json:"nestedPointerMap" yaml:"nestedPointerMap"` +} + +// references returns referenceableValues for the given doc and field for testing purposes. +func references(t *testing.T, doc, field any, mapKey string) (haystack, needle referenceableValue) { + t.Helper() + derefedField := pointerDeref(reflect.ValueOf(field)) + fieldRef := referenceableValue{ + value: derefedField, + addr: derefedField.UnsafeAddr(), + _type: derefedField.Type(), + mapKey: mapKey, + } + derefedDoc := pointerDeref(reflect.ValueOf(doc)) + docRef := referenceableValue{ + value: derefedDoc, + addr: derefedDoc.UnsafeAddr(), + _type: derefedDoc.Type(), + } + return docRef, fieldRef +} diff --git a/internal/validation/validation.go b/internal/validation/validation.go new file mode 100644 index 000000000..84b37aa40 --- /dev/null +++ b/internal/validation/validation.go @@ -0,0 +1,48 @@ +/* +Copyright (c) Edgeless Systems GmbH + +SPDX-License-Identifier: AGPL-3.0-only +*/ + +/* +Package validation provides a unified document validation interface for use within the Constellation CLI. + +It validates documents that specify a set of constraints on their content. +*/ +package validation + +import "errors" + +// NewValidator creates a new Validator. +func NewValidator() *Validator { + return &Validator{} +} + +// Validator validates documents. +type Validator struct{} + +// Validatable is implemented by documents that can be validated. +// It returns a list of constraints that must be satisfied for the document to be valid. +type Validatable interface { + Constraints() []Constraint +} + +// ValidateOptions are the options to use when validating a document. +type ValidateOptions struct { + // FailFast stops validation on the first error. + FailFast bool +} + +// Validate validates a document using the given options. +func (v *Validator) Validate(doc Validatable, opts ValidateOptions) error { + var retErr error + for _, c := range doc.Constraints() { + if err := c.Satisfied(); err != nil { + if opts.FailFast { + return err + } + retErr = errors.Join(retErr, err) + } + } + return retErr +} diff --git a/internal/validation/validation_test.go b/internal/validation/validation_test.go new file mode 100644 index 000000000..7d8d70eec --- /dev/null +++ b/internal/validation/validation_test.go @@ -0,0 +1,212 @@ +/* +Copyright (c) Edgeless Systems GmbH + +SPDX-License-Identifier: AGPL-3.0-only +*/ + +package validation + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestValidate(t *testing.T) { + testCases := map[string]struct { + doc Validatable + opts ValidateOptions + wantErr bool + errAssertion func(*assert.Assertions, error) bool + }{ + "valid": { + doc: &exampleDoc{ + StrField: "abc", + NumField: 42, + MapField: &map[string]string{ + "empty": "", + }, + NotEmptyField: "certainly not.", + MatchRegexField: "abc", + }, + opts: ValidateOptions{}, + }, + "strField is not abc": { + doc: &exampleDoc{ + StrField: "def", + NumField: 42, + MapField: &map[string]string{ + "empty": "", + }, + NotEmptyField: "certainly not.", + MatchRegexField: "abc", + }, + wantErr: true, + errAssertion: func(assert *assert.Assertions, err error) bool { + return assert.Contains(err.Error(), "validating exampleDoc.strField: def must be abc") + }, + opts: ValidateOptions{}, + }, + "numField is not 42": { + doc: &exampleDoc{ + StrField: "abc", + NumField: 43, + MapField: &map[string]string{ + "empty": "", + }, + NotEmptyField: "certainly not.", + MatchRegexField: "abc", + }, + wantErr: true, + errAssertion: func(assert *assert.Assertions, err error) bool { + return assert.Contains(err.Error(), "validating exampleDoc.numField: 43 must be equal to 42") + }, + }, + "multiple errors": { + doc: &exampleDoc{ + StrField: "def", + NumField: 43, + MapField: &map[string]string{ + "empty": "", + }, + NotEmptyField: "certainly not.", + MatchRegexField: "abc", + }, + wantErr: true, + errAssertion: func(assert *assert.Assertions, err error) bool { + return assert.Contains(err.Error(), "validating exampleDoc.strField: def must be abc") && + assert.Contains(err.Error(), "validating exampleDoc.numField: 43 must be equal to 42") + }, + opts: ValidateOptions{}, + }, + "multiple errors, fail fast": { + doc: &exampleDoc{ + StrField: "def", + NumField: 43, + MapField: &map[string]string{ + "empty": "", + }, + NotEmptyField: "certainly not.", + MatchRegexField: "abc", + }, + wantErr: true, + errAssertion: func(assert *assert.Assertions, err error) bool { + return assert.Contains(err.Error(), "validating exampleDoc.strField: def must be abc") + }, + opts: ValidateOptions{ + FailFast: true, + }, + }, + "map field is not empty": { + doc: &exampleDoc{ + StrField: "abc", + NumField: 42, + MapField: &map[string]string{ + "empty": "haha!", + }, + NotEmptyField: "certainly not.", + MatchRegexField: "abc", + }, + wantErr: true, + errAssertion: func(assert *assert.Assertions, err error) bool { + return assert.Contains(err.Error(), "validating exampleDoc.mapField[\"empty\"]: haha! must be empty") + }, + opts: ValidateOptions{ + FailFast: true, + }, + }, + "empty field is not empty": { + doc: &exampleDoc{ + StrField: "abc", + NumField: 42, + MapField: &map[string]string{ + "empty": "", + }, + NotEmptyField: "", + MatchRegexField: "abc", + }, + wantErr: true, + errAssertion: func(assert *assert.Assertions, err error) bool { + return assert.Contains(err.Error(), "validating exampleDoc.notEmptyField: must not be empty") + }, + opts: ValidateOptions{ + FailFast: true, + }, + }, + "regex doesnt match": { + doc: &exampleDoc{ + StrField: "abc", + NumField: 42, + MapField: &map[string]string{ + "empty": "", + }, + NotEmptyField: "certainly not!", + MatchRegexField: "dontmatch", + }, + wantErr: true, + errAssertion: func(assert *assert.Assertions, err error) bool { + return assert.Contains(err.Error(), "validating exampleDoc.matchRegexField: dontmatch must match the pattern ^a.c$") + }, + opts: ValidateOptions{ + FailFast: true, + }, + }, + } + + for name, tc := range testCases { + t.Run(name, func(t *testing.T) { + assert := assert.New(t) + require := require.New(t) + + err := NewValidator().Validate(tc.doc, tc.opts) + if tc.wantErr { + require.Error(err) + if !tc.errAssertion(assert, err) { + t.Fatalf("unexpected error: %v", err) + } + } else { + require.NoError(err) + } + }) + } +} + +type exampleDoc struct { + StrField string `json:"strField"` + NumField int `json:"numField"` + MapField *map[string]string `json:"mapField"` + NotEmptyField string `json:"notEmptyField"` + MatchRegexField string `json:"matchRegexField"` +} + +// Constraints implements the Validatable interface. +func (d *exampleDoc) Constraints() []Constraint { + mapField := *(d.MapField) + + return []Constraint{ + d.strFieldNeedsToBeAbc(). + WithFieldTrace(d, &d.StrField), + Equal(d.NumField, 42). + WithFieldTrace(d, &d.NumField), + Empty(mapField["empty"]). + WithMapFieldTrace(d, d.MapField, "empty"), + NotEmpty(d.NotEmptyField). + WithFieldTrace(d, &d.NotEmptyField), + MatchRegex(d.MatchRegexField, "^a.c$"). + WithFieldTrace(d, &d.MatchRegexField), + } +} + +// StrFieldNeedsToBeAbc is an example for a custom constraint. +func (d *exampleDoc) strFieldNeedsToBeAbc() *Constraint { + return &Constraint{ + Satisfied: func() error { + if d.StrField != "abc" { + return fmt.Errorf("%s must be abc", d.StrField) + } + return nil + }, + } +}