internal: fix unmarshalling attestation version numbers from JSON (#2187)

* Fix unmarshalling attestation version numbers from JSON

* Add unit test for UnmarshalJSON

---------

Signed-off-by: Daniel Weiße <dw@edgeless.systems>
This commit is contained in:
Daniel Weiße 2023-08-09 15:11:14 +02:00 committed by GitHub
parent 656cdbb4bb
commit c9cae643e2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 154 additions and 53 deletions

View File

@ -33,12 +33,11 @@ import (
func TestUpgradeApply(t *testing.T) { func TestUpgradeApply(t *testing.T) {
someErr := errors.New("some error") someErr := errors.New("some error")
testCases := map[string]struct { testCases := map[string]struct {
upgrader stubUpgrader upgrader stubUpgrader
fetcher stubImageFetcher fetcher stubImageFetcher
wantErr bool wantErr bool
yesFlag bool yesFlag bool
stdin string stdin string
remoteAttestationCfg config.AttestationCfg // attestation config returned by the stub Kubernetes client
}{ }{
"success": { "success": {
upgrader: stubUpgrader{currentConfig: config.DefaultForAzureSEVSNP()}, upgrader: stubUpgrader{currentConfig: config.DefaultForAzureSEVSNP()},
@ -143,16 +142,15 @@ func TestUpgradeApply(t *testing.T) {
handler := file.NewHandler(afero.NewMemMapFs()) handler := file.NewHandler(afero.NewMemMapFs())
cfg := defaultConfigWithExpectedMeasurements(t, config.Default(), cloudprovider.Azure) cfg := defaultConfigWithExpectedMeasurements(t, config.Default(), cloudprovider.Azure)
if tc.remoteAttestationCfg == nil { remoteAttestationCfg := fakeAttestationConfigFromCluster(cmd.Context(), t, cloudprovider.Azure)
tc.remoteAttestationCfg = fakeAttestationConfigFromCluster(cmd.Context(), t, cloudprovider.Azure)
}
require.NoError(handler.WriteYAML(constants.ConfigFilename, cfg)) require.NoError(handler.WriteYAML(constants.ConfigFilename, cfg))
require.NoError(handler.WriteJSON(constants.ClusterIDsFilename, clusterid.File{})) require.NoError(handler.WriteJSON(constants.ClusterIDsFilename, clusterid.File{}))
upgrader := upgradeApplyCmd{upgrader: tc.upgrader, log: logger.NewTest(t), imageFetcher: tc.fetcher, configFetcher: stubAttestationFetcher{}} upgrader := upgradeApplyCmd{upgrader: tc.upgrader, log: logger.NewTest(t), imageFetcher: tc.fetcher, configFetcher: stubAttestationFetcher{}}
stubStableClientFactory := func(_ string) (getConfigMapper, error) { stubStableClientFactory := func(_ string) (getConfigMapper, error) {
return stubGetConfigMap{tc.remoteAttestationCfg}, nil return stubGetConfigMap{remoteAttestationCfg}, nil
} }
err := upgrader.upgradeApply(cmd, handler, stubStableClientFactory) err := upgrader.upgradeApply(cmd, handler, stubStableClientFactory)
if tc.wantErr { if tc.wantErr {

View File

@ -58,15 +58,24 @@ func (v AttestationVersion) MarshalJSON() ([]byte, error) {
// UnmarshalJSON implements a custom unmarshaller to resolve "latest" values. // UnmarshalJSON implements a custom unmarshaller to resolve "latest" values.
func (v *AttestationVersion) UnmarshalJSON(data []byte) (err error) { func (v *AttestationVersion) UnmarshalJSON(data []byte) (err error) {
var rawUnmarshal string // JSON has two distinct ways to represent numbers and strings.
if err := json.Unmarshal(data, &rawUnmarshal); err != nil { // This means we cannot simply unmarshal to string, like with YAML.
return fmt.Errorf("raw unmarshal: %w", err) // Unmarshalling to `any` causes Go to unmarshal numbers to float64.
// Therefore, try to unmarshal to string, and then to int, instead of using type assertions.
var unmarshalString string
if err := json.Unmarshal(data, &unmarshalString); err != nil {
var unmarshalInt int64
if err := json.Unmarshal(data, &unmarshalInt); err != nil {
return fmt.Errorf("unable to unmarshal to string or int: %w", err)
}
unmarshalString = strconv.FormatInt(unmarshalInt, 10)
} }
return v.parseRawUnmarshal(rawUnmarshal)
return v.parseRawUnmarshal(unmarshalString)
} }
func (v *AttestationVersion) parseRawUnmarshal(str string) error { func (v *AttestationVersion) parseRawUnmarshal(str string) error {
if strings.HasPrefix(str, "0") { if strings.HasPrefix(str, "0") && len(str) != 1 {
return fmt.Errorf("no format with prefixed 0 (octal, hexadecimal) allowed: %s", str) return fmt.Errorf("no format with prefixed 0 (octal, hexadecimal) allowed: %s", str)
} }
if strings.ToLower(str) == "latest" { if strings.ToLower(str) == "latest" {

View File

@ -7,6 +7,7 @@ SPDX-License-Identifier: AGPL-3.0-only
package config package config
import ( import (
"encoding/json"
"testing" "testing"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
@ -14,21 +15,18 @@ import (
) )
func TestVersionMarshalYAML(t *testing.T) { func TestVersionMarshalYAML(t *testing.T) {
tests := []struct { tests := map[string]struct {
name string
sut AttestationVersion sut AttestationVersion
want string want string
}{ }{
{ "isLatest resolves to latest": {
name: "isLatest resolves to latest",
sut: AttestationVersion{ sut: AttestationVersion{
Value: 1, Value: 1,
WantLatest: true, WantLatest: true,
}, },
want: "latest\n", want: "latest\n",
}, },
{ "value 5 resolves to 5": {
name: "value 5 resolves to 5",
sut: AttestationVersion{ sut: AttestationVersion{
Value: 5, Value: 5,
WantLatest: false, WantLatest: false,
@ -36,81 +34,177 @@ func TestVersionMarshalYAML(t *testing.T) {
want: "5\n", want: "5\n",
}, },
} }
for _, tt := range tests { for name, tc := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
bt, err := yaml.Marshal(tt.sut) require := require.New(t)
require.NoError(t, err)
require.Equal(t, tt.want, string(bt)) bt, err := yaml.Marshal(tc.sut)
require.NoError(err)
require.Equal(tc.want, string(bt))
}) })
} }
} }
func TestVersionUnmarshalYAML(t *testing.T) { func TestVersionUnmarshalYAML(t *testing.T) {
tests := []struct { tests := map[string]struct {
name string
sut string sut string
want AttestationVersion want AttestationVersion
wantErr bool wantErr bool
}{ }{
{ "latest resolves to isLatest": {
name: "latest resolves to isLatest", sut: "latest",
sut: "latest",
want: AttestationVersion{ want: AttestationVersion{
Value: 0, Value: 0,
WantLatest: true, WantLatest: true,
}, },
wantErr: false, wantErr: false,
}, },
{ "1 resolves to value 1": {
name: "1 resolves to value 1", sut: "1",
sut: "1",
want: AttestationVersion{ want: AttestationVersion{
Value: 1, Value: 1,
WantLatest: false, WantLatest: false,
}, },
wantErr: false, wantErr: false,
}, },
{ "max uint8+1 errors": {
name: "max uint8+1 errors",
sut: "256", sut: "256",
wantErr: true, wantErr: true,
}, },
{ "-1 errors": {
name: "-1 errors",
sut: "-1", sut: "-1",
wantErr: true, wantErr: true,
}, },
{ "2.6 errors": {
name: "2.6 errors",
sut: "2.6", sut: "2.6",
wantErr: true, wantErr: true,
}, },
{ "2.0 errors": {
name: "2.0 errors",
sut: "2.0", sut: "2.0",
wantErr: true, wantErr: true,
}, },
{ "hex format is invalid": {
name: "hex format is invalid",
sut: "0x10", sut: "0x10",
wantErr: true, wantErr: true,
}, },
{ "octal format is invalid": {
name: "octal format is invalid",
sut: "010", sut: "010",
wantErr: true, wantErr: true,
}, },
"0 resolves to value 0": {
sut: "0",
want: AttestationVersion{
Value: 0,
WantLatest: false,
},
},
"00 errors": {
sut: "00",
wantErr: true,
},
} }
for _, tt := range tests { for name, tc := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
require := require.New(t)
var sut AttestationVersion var sut AttestationVersion
err := yaml.Unmarshal([]byte(tt.sut), &sut) err := yaml.Unmarshal([]byte(tc.sut), &sut)
if tt.wantErr { if tc.wantErr {
require.Error(t, err) require.Error(err)
return return
} }
require.NoError(t, err) require.NoError(err)
require.Equal(t, tt.want, sut) require.Equal(tc.want, sut)
})
}
}
func TestVersionUnmarshalJSON(t *testing.T) {
tests := map[string]struct {
sut string
want AttestationVersion
wantErr bool
}{
"latest resolves to isLatest": {
sut: `"latest"`,
want: AttestationVersion{
Value: 0,
WantLatest: true,
},
},
"1 resolves to value 1": {
sut: "1",
want: AttestationVersion{
Value: 1,
WantLatest: false,
},
},
"quoted number resolves to value": {
sut: `"1"`,
want: AttestationVersion{
Value: 1,
WantLatest: false,
},
},
"quoted float errors": {
sut: `"1.0"`,
wantErr: true,
},
"max uint8+1 errors": {
sut: "256",
wantErr: true,
},
"-1 errors": {
sut: "-1",
wantErr: true,
},
"2.6 errors": {
sut: "2.6",
wantErr: true,
},
"2.0 errors": {
sut: "2.0",
wantErr: true,
},
"hex format is invalid": {
sut: "0x10",
wantErr: true,
},
"octal format is invalid": {
sut: "010",
wantErr: true,
},
"0 resolves to value 0": {
sut: "0",
want: AttestationVersion{
Value: 0,
WantLatest: false,
},
},
"quoted 0 resolves to value 0": {
sut: `"0"`,
want: AttestationVersion{
Value: 0,
WantLatest: false,
},
},
"00 errors": {
sut: "00",
wantErr: true,
},
}
for name, tc := range tests {
t.Run(name, func(t *testing.T) {
require := require.New(t)
var sut AttestationVersion
err := json.Unmarshal([]byte(tc.sut), &sut)
if tc.wantErr {
require.Error(err)
return
}
require.NoError(err)
require.Equal(tc.want, sut)
}) })
} }
} }