mirror of
https://github.com/edgelesssys/constellation.git
synced 2025-02-02 10:35:08 -05:00
internal: fix unmarshalling attestation version numbers from JSON (#2187)
* Fix unmarshalling attestation version numbers from JSON * Add unit test for UnmarshalJSON --------- Signed-off-by: Daniel Weiße <dw@edgeless.systems>
This commit is contained in:
parent
656cdbb4bb
commit
c9cae643e2
@ -33,12 +33,11 @@ import (
|
||||
func TestUpgradeApply(t *testing.T) {
|
||||
someErr := errors.New("some error")
|
||||
testCases := map[string]struct {
|
||||
upgrader stubUpgrader
|
||||
fetcher stubImageFetcher
|
||||
wantErr bool
|
||||
yesFlag bool
|
||||
stdin string
|
||||
remoteAttestationCfg config.AttestationCfg // attestation config returned by the stub Kubernetes client
|
||||
upgrader stubUpgrader
|
||||
fetcher stubImageFetcher
|
||||
wantErr bool
|
||||
yesFlag bool
|
||||
stdin string
|
||||
}{
|
||||
"success": {
|
||||
upgrader: stubUpgrader{currentConfig: config.DefaultForAzureSEVSNP()},
|
||||
@ -143,16 +142,15 @@ func TestUpgradeApply(t *testing.T) {
|
||||
handler := file.NewHandler(afero.NewMemMapFs())
|
||||
cfg := defaultConfigWithExpectedMeasurements(t, config.Default(), cloudprovider.Azure)
|
||||
|
||||
if tc.remoteAttestationCfg == nil {
|
||||
tc.remoteAttestationCfg = fakeAttestationConfigFromCluster(cmd.Context(), t, cloudprovider.Azure)
|
||||
}
|
||||
remoteAttestationCfg := fakeAttestationConfigFromCluster(cmd.Context(), t, cloudprovider.Azure)
|
||||
|
||||
require.NoError(handler.WriteYAML(constants.ConfigFilename, cfg))
|
||||
require.NoError(handler.WriteJSON(constants.ClusterIDsFilename, clusterid.File{}))
|
||||
|
||||
upgrader := upgradeApplyCmd{upgrader: tc.upgrader, log: logger.NewTest(t), imageFetcher: tc.fetcher, configFetcher: stubAttestationFetcher{}}
|
||||
|
||||
stubStableClientFactory := func(_ string) (getConfigMapper, error) {
|
||||
return stubGetConfigMap{tc.remoteAttestationCfg}, nil
|
||||
return stubGetConfigMap{remoteAttestationCfg}, nil
|
||||
}
|
||||
err := upgrader.upgradeApply(cmd, handler, stubStableClientFactory)
|
||||
if tc.wantErr {
|
||||
|
@ -58,15 +58,24 @@ func (v AttestationVersion) MarshalJSON() ([]byte, error) {
|
||||
|
||||
// UnmarshalJSON implements a custom unmarshaller to resolve "latest" values.
|
||||
func (v *AttestationVersion) UnmarshalJSON(data []byte) (err error) {
|
||||
var rawUnmarshal string
|
||||
if err := json.Unmarshal(data, &rawUnmarshal); err != nil {
|
||||
return fmt.Errorf("raw unmarshal: %w", err)
|
||||
// JSON has two distinct ways to represent numbers and strings.
|
||||
// This means we cannot simply unmarshal to string, like with YAML.
|
||||
// Unmarshalling to `any` causes Go to unmarshal numbers to float64.
|
||||
// Therefore, try to unmarshal to string, and then to int, instead of using type assertions.
|
||||
var unmarshalString string
|
||||
if err := json.Unmarshal(data, &unmarshalString); err != nil {
|
||||
var unmarshalInt int64
|
||||
if err := json.Unmarshal(data, &unmarshalInt); err != nil {
|
||||
return fmt.Errorf("unable to unmarshal to string or int: %w", err)
|
||||
}
|
||||
unmarshalString = strconv.FormatInt(unmarshalInt, 10)
|
||||
}
|
||||
return v.parseRawUnmarshal(rawUnmarshal)
|
||||
|
||||
return v.parseRawUnmarshal(unmarshalString)
|
||||
}
|
||||
|
||||
func (v *AttestationVersion) parseRawUnmarshal(str string) error {
|
||||
if strings.HasPrefix(str, "0") {
|
||||
if strings.HasPrefix(str, "0") && len(str) != 1 {
|
||||
return fmt.Errorf("no format with prefixed 0 (octal, hexadecimal) allowed: %s", str)
|
||||
}
|
||||
if strings.ToLower(str) == "latest" {
|
||||
|
@ -7,6 +7,7 @@ SPDX-License-Identifier: AGPL-3.0-only
|
||||
package config
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
@ -14,21 +15,18 @@ import (
|
||||
)
|
||||
|
||||
func TestVersionMarshalYAML(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
tests := map[string]struct {
|
||||
sut AttestationVersion
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "isLatest resolves to latest",
|
||||
"isLatest resolves to latest": {
|
||||
sut: AttestationVersion{
|
||||
Value: 1,
|
||||
WantLatest: true,
|
||||
},
|
||||
want: "latest\n",
|
||||
},
|
||||
{
|
||||
name: "value 5 resolves to 5",
|
||||
"value 5 resolves to 5": {
|
||||
sut: AttestationVersion{
|
||||
Value: 5,
|
||||
WantLatest: false,
|
||||
@ -36,81 +34,177 @@ func TestVersionMarshalYAML(t *testing.T) {
|
||||
want: "5\n",
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
bt, err := yaml.Marshal(tt.sut)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, tt.want, string(bt))
|
||||
for name, tc := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
require := require.New(t)
|
||||
|
||||
bt, err := yaml.Marshal(tc.sut)
|
||||
require.NoError(err)
|
||||
require.Equal(tc.want, string(bt))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestVersionUnmarshalYAML(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
tests := map[string]struct {
|
||||
sut string
|
||||
want AttestationVersion
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "latest resolves to isLatest",
|
||||
sut: "latest",
|
||||
"latest resolves to isLatest": {
|
||||
sut: "latest",
|
||||
want: AttestationVersion{
|
||||
Value: 0,
|
||||
WantLatest: true,
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "1 resolves to value 1",
|
||||
sut: "1",
|
||||
"1 resolves to value 1": {
|
||||
sut: "1",
|
||||
want: AttestationVersion{
|
||||
Value: 1,
|
||||
WantLatest: false,
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "max uint8+1 errors",
|
||||
"max uint8+1 errors": {
|
||||
sut: "256",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "-1 errors",
|
||||
"-1 errors": {
|
||||
sut: "-1",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "2.6 errors",
|
||||
"2.6 errors": {
|
||||
sut: "2.6",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "2.0 errors",
|
||||
"2.0 errors": {
|
||||
sut: "2.0",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "hex format is invalid",
|
||||
"hex format is invalid": {
|
||||
sut: "0x10",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "octal format is invalid",
|
||||
"octal format is invalid": {
|
||||
sut: "010",
|
||||
wantErr: true,
|
||||
},
|
||||
"0 resolves to value 0": {
|
||||
sut: "0",
|
||||
want: AttestationVersion{
|
||||
Value: 0,
|
||||
WantLatest: false,
|
||||
},
|
||||
},
|
||||
"00 errors": {
|
||||
sut: "00",
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
for name, tc := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
require := require.New(t)
|
||||
|
||||
var sut AttestationVersion
|
||||
err := yaml.Unmarshal([]byte(tt.sut), &sut)
|
||||
if tt.wantErr {
|
||||
require.Error(t, err)
|
||||
err := yaml.Unmarshal([]byte(tc.sut), &sut)
|
||||
if tc.wantErr {
|
||||
require.Error(err)
|
||||
return
|
||||
}
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, tt.want, sut)
|
||||
require.NoError(err)
|
||||
require.Equal(tc.want, sut)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestVersionUnmarshalJSON(t *testing.T) {
|
||||
tests := map[string]struct {
|
||||
sut string
|
||||
want AttestationVersion
|
||||
wantErr bool
|
||||
}{
|
||||
"latest resolves to isLatest": {
|
||||
sut: `"latest"`,
|
||||
want: AttestationVersion{
|
||||
Value: 0,
|
||||
WantLatest: true,
|
||||
},
|
||||
},
|
||||
"1 resolves to value 1": {
|
||||
sut: "1",
|
||||
want: AttestationVersion{
|
||||
Value: 1,
|
||||
WantLatest: false,
|
||||
},
|
||||
},
|
||||
"quoted number resolves to value": {
|
||||
sut: `"1"`,
|
||||
want: AttestationVersion{
|
||||
Value: 1,
|
||||
WantLatest: false,
|
||||
},
|
||||
},
|
||||
"quoted float errors": {
|
||||
sut: `"1.0"`,
|
||||
wantErr: true,
|
||||
},
|
||||
"max uint8+1 errors": {
|
||||
sut: "256",
|
||||
wantErr: true,
|
||||
},
|
||||
"-1 errors": {
|
||||
sut: "-1",
|
||||
wantErr: true,
|
||||
},
|
||||
"2.6 errors": {
|
||||
sut: "2.6",
|
||||
wantErr: true,
|
||||
},
|
||||
"2.0 errors": {
|
||||
sut: "2.0",
|
||||
wantErr: true,
|
||||
},
|
||||
"hex format is invalid": {
|
||||
sut: "0x10",
|
||||
wantErr: true,
|
||||
},
|
||||
"octal format is invalid": {
|
||||
sut: "010",
|
||||
wantErr: true,
|
||||
},
|
||||
"0 resolves to value 0": {
|
||||
sut: "0",
|
||||
want: AttestationVersion{
|
||||
Value: 0,
|
||||
WantLatest: false,
|
||||
},
|
||||
},
|
||||
"quoted 0 resolves to value 0": {
|
||||
sut: `"0"`,
|
||||
want: AttestationVersion{
|
||||
Value: 0,
|
||||
WantLatest: false,
|
||||
},
|
||||
},
|
||||
"00 errors": {
|
||||
sut: "00",
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
for name, tc := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
require := require.New(t)
|
||||
|
||||
var sut AttestationVersion
|
||||
err := json.Unmarshal([]byte(tc.sut), &sut)
|
||||
if tc.wantErr {
|
||||
require.Error(err)
|
||||
return
|
||||
}
|
||||
require.NoError(err)
|
||||
require.Equal(tc.want, sut)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user