diff --git a/cli/internal/cmd/upgradeapply_test.go b/cli/internal/cmd/upgradeapply_test.go index 31dcb21dc..8f0eea2b6 100644 --- a/cli/internal/cmd/upgradeapply_test.go +++ b/cli/internal/cmd/upgradeapply_test.go @@ -33,12 +33,11 @@ import ( func TestUpgradeApply(t *testing.T) { someErr := errors.New("some error") testCases := map[string]struct { - upgrader stubUpgrader - fetcher stubImageFetcher - wantErr bool - yesFlag bool - stdin string - remoteAttestationCfg config.AttestationCfg // attestation config returned by the stub Kubernetes client + upgrader stubUpgrader + fetcher stubImageFetcher + wantErr bool + yesFlag bool + stdin string }{ "success": { upgrader: stubUpgrader{currentConfig: config.DefaultForAzureSEVSNP()}, @@ -143,16 +142,15 @@ func TestUpgradeApply(t *testing.T) { handler := file.NewHandler(afero.NewMemMapFs()) cfg := defaultConfigWithExpectedMeasurements(t, config.Default(), cloudprovider.Azure) - if tc.remoteAttestationCfg == nil { - tc.remoteAttestationCfg = fakeAttestationConfigFromCluster(cmd.Context(), t, cloudprovider.Azure) - } + remoteAttestationCfg := fakeAttestationConfigFromCluster(cmd.Context(), t, cloudprovider.Azure) + require.NoError(handler.WriteYAML(constants.ConfigFilename, cfg)) require.NoError(handler.WriteJSON(constants.ClusterIDsFilename, clusterid.File{})) upgrader := upgradeApplyCmd{upgrader: tc.upgrader, log: logger.NewTest(t), imageFetcher: tc.fetcher, configFetcher: stubAttestationFetcher{}} stubStableClientFactory := func(_ string) (getConfigMapper, error) { - return stubGetConfigMap{tc.remoteAttestationCfg}, nil + return stubGetConfigMap{remoteAttestationCfg}, nil } err := upgrader.upgradeApply(cmd, handler, stubStableClientFactory) if tc.wantErr { diff --git a/internal/config/attestationversion.go b/internal/config/attestationversion.go index 4711701de..a7949c5c3 100644 --- a/internal/config/attestationversion.go +++ b/internal/config/attestationversion.go @@ -58,15 +58,24 @@ func (v AttestationVersion) MarshalJSON() ([]byte, error) { // UnmarshalJSON implements a custom unmarshaller to resolve "latest" values. func (v *AttestationVersion) UnmarshalJSON(data []byte) (err error) { - var rawUnmarshal string - if err := json.Unmarshal(data, &rawUnmarshal); err != nil { - return fmt.Errorf("raw unmarshal: %w", err) + // JSON has two distinct ways to represent numbers and strings. + // This means we cannot simply unmarshal to string, like with YAML. + // Unmarshalling to `any` causes Go to unmarshal numbers to float64. + // Therefore, try to unmarshal to string, and then to int, instead of using type assertions. + var unmarshalString string + if err := json.Unmarshal(data, &unmarshalString); err != nil { + var unmarshalInt int64 + if err := json.Unmarshal(data, &unmarshalInt); err != nil { + return fmt.Errorf("unable to unmarshal to string or int: %w", err) + } + unmarshalString = strconv.FormatInt(unmarshalInt, 10) } - return v.parseRawUnmarshal(rawUnmarshal) + + return v.parseRawUnmarshal(unmarshalString) } func (v *AttestationVersion) parseRawUnmarshal(str string) error { - if strings.HasPrefix(str, "0") { + if strings.HasPrefix(str, "0") && len(str) != 1 { return fmt.Errorf("no format with prefixed 0 (octal, hexadecimal) allowed: %s", str) } if strings.ToLower(str) == "latest" { diff --git a/internal/config/attestationversion_test.go b/internal/config/attestationversion_test.go index 0735671d9..52d68e2a8 100644 --- a/internal/config/attestationversion_test.go +++ b/internal/config/attestationversion_test.go @@ -7,6 +7,7 @@ SPDX-License-Identifier: AGPL-3.0-only package config import ( + "encoding/json" "testing" "github.com/stretchr/testify/require" @@ -14,21 +15,18 @@ import ( ) func TestVersionMarshalYAML(t *testing.T) { - tests := []struct { - name string + tests := map[string]struct { sut AttestationVersion want string }{ - { - name: "isLatest resolves to latest", + "isLatest resolves to latest": { sut: AttestationVersion{ Value: 1, WantLatest: true, }, want: "latest\n", }, - { - name: "value 5 resolves to 5", + "value 5 resolves to 5": { sut: AttestationVersion{ Value: 5, WantLatest: false, @@ -36,81 +34,177 @@ func TestVersionMarshalYAML(t *testing.T) { want: "5\n", }, } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - bt, err := yaml.Marshal(tt.sut) - require.NoError(t, err) - require.Equal(t, tt.want, string(bt)) + for name, tc := range tests { + t.Run(name, func(t *testing.T) { + require := require.New(t) + + bt, err := yaml.Marshal(tc.sut) + require.NoError(err) + require.Equal(tc.want, string(bt)) }) } } func TestVersionUnmarshalYAML(t *testing.T) { - tests := []struct { - name string + tests := map[string]struct { sut string want AttestationVersion wantErr bool }{ - { - name: "latest resolves to isLatest", - sut: "latest", + "latest resolves to isLatest": { + sut: "latest", want: AttestationVersion{ Value: 0, WantLatest: true, }, wantErr: false, }, - { - name: "1 resolves to value 1", - sut: "1", + "1 resolves to value 1": { + sut: "1", want: AttestationVersion{ Value: 1, WantLatest: false, }, wantErr: false, }, - { - name: "max uint8+1 errors", + "max uint8+1 errors": { sut: "256", wantErr: true, }, - { - name: "-1 errors", + "-1 errors": { sut: "-1", wantErr: true, }, - { - name: "2.6 errors", + "2.6 errors": { sut: "2.6", wantErr: true, }, - { - name: "2.0 errors", + "2.0 errors": { sut: "2.0", wantErr: true, }, - { - name: "hex format is invalid", + "hex format is invalid": { sut: "0x10", wantErr: true, }, - { - name: "octal format is invalid", + "octal format is invalid": { sut: "010", wantErr: true, }, + "0 resolves to value 0": { + sut: "0", + want: AttestationVersion{ + Value: 0, + WantLatest: false, + }, + }, + "00 errors": { + sut: "00", + wantErr: true, + }, } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { + for name, tc := range tests { + t.Run(name, func(t *testing.T) { + require := require.New(t) + var sut AttestationVersion - err := yaml.Unmarshal([]byte(tt.sut), &sut) - if tt.wantErr { - require.Error(t, err) + err := yaml.Unmarshal([]byte(tc.sut), &sut) + if tc.wantErr { + require.Error(err) return } - require.NoError(t, err) - require.Equal(t, tt.want, sut) + require.NoError(err) + require.Equal(tc.want, sut) + }) + } +} + +func TestVersionUnmarshalJSON(t *testing.T) { + tests := map[string]struct { + sut string + want AttestationVersion + wantErr bool + }{ + "latest resolves to isLatest": { + sut: `"latest"`, + want: AttestationVersion{ + Value: 0, + WantLatest: true, + }, + }, + "1 resolves to value 1": { + sut: "1", + want: AttestationVersion{ + Value: 1, + WantLatest: false, + }, + }, + "quoted number resolves to value": { + sut: `"1"`, + want: AttestationVersion{ + Value: 1, + WantLatest: false, + }, + }, + "quoted float errors": { + sut: `"1.0"`, + wantErr: true, + }, + "max uint8+1 errors": { + sut: "256", + wantErr: true, + }, + "-1 errors": { + sut: "-1", + wantErr: true, + }, + "2.6 errors": { + sut: "2.6", + wantErr: true, + }, + "2.0 errors": { + sut: "2.0", + wantErr: true, + }, + "hex format is invalid": { + sut: "0x10", + wantErr: true, + }, + "octal format is invalid": { + sut: "010", + wantErr: true, + }, + "0 resolves to value 0": { + sut: "0", + want: AttestationVersion{ + Value: 0, + WantLatest: false, + }, + }, + "quoted 0 resolves to value 0": { + sut: `"0"`, + want: AttestationVersion{ + Value: 0, + WantLatest: false, + }, + }, + "00 errors": { + sut: "00", + wantErr: true, + }, + } + for name, tc := range tests { + t.Run(name, func(t *testing.T) { + require := require.New(t) + + var sut AttestationVersion + err := json.Unmarshal([]byte(tc.sut), &sut) + if tc.wantErr { + require.Error(err) + return + } + require.NoError(err) + require.Equal(tc.want, sut) }) } }