strict input validation on attestation version numbers (#2180)

This commit is contained in:
Adrian Stobbe 2023-08-09 11:41:04 +02:00 committed by GitHub
parent d1febd7276
commit d8db9d0add
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 89 additions and 17 deletions

View File

@ -9,6 +9,8 @@ package config
import (
"encoding/json"
"fmt"
"math"
"strconv"
"strings"
)
@ -38,7 +40,7 @@ func (v AttestationVersion) MarshalYAML() (any, error) {
// UnmarshalYAML implements a custom unmarshaller to resolve "atest" values.
func (v *AttestationVersion) UnmarshalYAML(unmarshal func(any) error) error {
var rawUnmarshal any
var rawUnmarshal string
if err := unmarshal(&rawUnmarshal); err != nil {
return fmt.Errorf("raw unmarshal: %w", err)
}
@ -56,29 +58,29 @@ func (v AttestationVersion) MarshalJSON() ([]byte, error) {
// UnmarshalJSON implements a custom unmarshaller to resolve "latest" values.
func (v *AttestationVersion) UnmarshalJSON(data []byte) (err error) {
var rawUnmarshal any
var rawUnmarshal string
if err := json.Unmarshal(data, &rawUnmarshal); err != nil {
return fmt.Errorf("raw unmarshal: %w", err)
}
return v.parseRawUnmarshal(rawUnmarshal)
}
func (v *AttestationVersion) parseRawUnmarshal(rawUnmarshal any) error {
switch s := rawUnmarshal.(type) {
case string:
if strings.ToLower(s) == "latest" {
v.WantLatest = true
v.Value = placeholderVersionValue
} else {
return fmt.Errorf("invalid version value: %s", s)
func (v *AttestationVersion) parseRawUnmarshal(str string) error {
if strings.HasPrefix(str, "0") {
return fmt.Errorf("no format with prefixed 0 (octal, hexadecimal) allowed: %s", str)
}
if strings.ToLower(str) == "latest" {
v.WantLatest = true
v.Value = placeholderVersionValue
} else {
ui, err := strconv.ParseUint(str, 10, 8)
if err != nil {
return fmt.Errorf("invalid version value: %s", str)
}
case int:
v.Value = uint8(s)
// yaml spec allows "1" as float64, so version number might come as a float: https://github.com/go-yaml/yaml/issues/430
case float64:
v.Value = uint8(s)
default:
return fmt.Errorf("invalid version value type: %s", s)
if ui > math.MaxUint8 {
return fmt.Errorf("integer value is out ouf uint8 range: %d", ui)
}
v.Value = uint8(ui)
}
return nil
}

View File

@ -44,3 +44,73 @@ func TestVersionMarshalYAML(t *testing.T) {
})
}
}
func TestVersionUnmarshalYAML(t *testing.T) {
tests := []struct {
name string
sut string
want AttestationVersion
wantErr bool
}{
{
name: "latest resolves to isLatest",
sut: "latest",
want: AttestationVersion{
Value: 0,
WantLatest: true,
},
wantErr: false,
},
{
name: "1 resolves to value 1",
sut: "1",
want: AttestationVersion{
Value: 1,
WantLatest: false,
},
wantErr: false,
},
{
name: "max uint8+1 errors",
sut: "256",
wantErr: true,
},
{
name: "-1 errors",
sut: "-1",
wantErr: true,
},
{
name: "2.6 errors",
sut: "2.6",
wantErr: true,
},
{
name: "2.0 errors",
sut: "2.0",
wantErr: true,
},
{
name: "hex format is invalid",
sut: "0x10",
wantErr: true,
},
{
name: "octal format is invalid",
sut: "010",
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var sut AttestationVersion
err := yaml.Unmarshal([]byte(tt.sut), &sut)
if tt.wantErr {
require.Error(t, err)
return
}
require.NoError(t, err)
require.Equal(t, tt.want, sut)
})
}
}