From 14103e4f89dfd075f3b6a654fded65824b47425c Mon Sep 17 00:00:00 2001 From: Fabian Kammel Date: Thu, 12 May 2022 10:15:00 +0200 Subject: [PATCH] Fix/config/measurements in yaml (#135) Custom type & marshal implementation for measurements to write base64 instead of single bytes --- cli/cloud/cloudcmd/validators_test.go | 9 +- internal/config/config.go | 22 ++--- internal/config/measurements.go | 33 +++++++ internal/config/measurements_test.go | 129 ++++++++++++++++++++++++++ 4 files changed, 177 insertions(+), 16 deletions(-) create mode 100644 internal/config/measurements.go create mode 100644 internal/config/measurements_test.go diff --git a/cli/cloud/cloudcmd/validators_test.go b/cli/cloud/cloudcmd/validators_test.go index bc0a01044..913e2e959 100644 --- a/cli/cloud/cloudcmd/validators_test.go +++ b/cli/cloud/cloudcmd/validators_test.go @@ -68,13 +68,16 @@ func TestNewValidators(t *testing.T) { conf := &config.Config{Provider: &config.ProviderConfig{}} if tc.provider == cloudprovider.GCP { - conf.Provider.GCP = &config.GCPConfig{Measurements: &tc.pcrs} + measurements := config.Measurements(tc.pcrs) + conf.Provider.GCP = &config.GCPConfig{Measurements: &measurements} } if tc.provider == cloudprovider.Azure { - conf.Provider.Azure = &config.AzureConfig{Measurements: &tc.pcrs} + measurements := config.Measurements(tc.pcrs) + conf.Provider.Azure = &config.AzureConfig{Measurements: &measurements} } if tc.provider == cloudprovider.QEMU { - conf.Provider.QEMU = &config.QEMUConfig{PCRs: &tc.pcrs} + measurements := config.Measurements(tc.pcrs) + conf.Provider.QEMU = &config.QEMUConfig{PCRs: &measurements} } validators, err := NewValidators(tc.provider, conf) diff --git a/internal/config/config.go b/internal/config/config.go index 9f021f7dd..3c2d67078 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -18,7 +18,7 @@ import ( var ( // gcpPCRs is a map of the expected PCR values for a GCP Constellation node. // TODO: Get a full list once we have stable releases. - gcpPCRs = map[uint32][]byte{ + gcpPCRs = Measurements{ 0: {0x0F, 0x35, 0xC2, 0x14, 0x60, 0x8D, 0x93, 0xC7, 0xA6, 0xE6, 0x8A, 0xE7, 0x35, 0x9B, 0x4A, 0x8B, 0xE5, 0xA0, 0xE9, 0x9E, 0xEA, 0x91, 0x07, 0xEC, 0xE4, 0x27, 0xC4, 0xDE, 0xA4, 0xE4, 0x39, 0xCF}, uint32(vtpm.PCRIndexOwnerID): {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, uint32(vtpm.PCRIndexClusterID): {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, @@ -26,12 +26,12 @@ var ( // azurePCRs is a map of the expected PCR values for an Azure Constellation node. // TODO: Get a full list once we have a working setup with stable releases. - azurePCRs = map[uint32][]byte{ + azurePCRs = Measurements{ uint32(vtpm.PCRIndexOwnerID): {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, uint32(vtpm.PCRIndexClusterID): {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, } - qemuPCRs = map[uint32][]byte{ + qemuPCRs = Measurements{ uint32(vtpm.PCRIndexOwnerID): {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, uint32(vtpm.PCRIndexClusterID): {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, } @@ -148,7 +148,7 @@ func Default() *Config { }, }, }, - Measurements: pcrPtr(azurePCRs), + Measurements: &azurePCRs, UserAssignedIdentity: proto.String("/subscriptions/0d202bbb-4fa7-4af8-8125-58c269a05435/resourceGroups/constellation-images/providers/Microsoft.ManagedIdentity/userAssignedIdentities/constellation-dev-identity"), }, GCP: &GCPConfig{ @@ -196,10 +196,10 @@ func Default() *Config { "roles/storage.admin", "roles/iam.serviceAccountUser", }, - Measurements: pcrPtr(gcpPCRs), + Measurements: &gcpPCRs, }, QEMU: &QEMUConfig{ - PCRs: pcrPtr(qemuPCRs), + PCRs: &qemuPCRs, }, }, } @@ -241,7 +241,7 @@ type AzureConfig struct { Location *string `yaml:"location,omitempty"` // TODO: This will be user input Image *string `yaml:"image,omitempty"` NetworkSecurityGroupInput *azureClient.NetworkSecurityGroupInput `yaml:"networkSecurityGroupInput,omitempty"` - Measurements *map[uint32][]byte `yaml:"measurements,omitempty"` + Measurements *Measurements `yaml:"measurements,omitempty"` UserAssignedIdentity *string `yaml:"userassignedIdentity,omitempty"` } @@ -254,15 +254,11 @@ type GCPConfig struct { FirewallInput *gcpClient.FirewallInput `yaml:"firewallInput,omitempty"` VPCsInput *gcpClient.VPCsInput `yaml:"vpcsInput,omitempty"` ServiceAccountRoles *[]string `yaml:"serviceAccountRoles,omitempty"` - Measurements *map[uint32][]byte `yaml:"measurements,omitempty"` + Measurements *Measurements `yaml:"measurements,omitempty"` } type QEMUConfig struct { - PCRs *map[uint32][]byte `yaml:"pcrs,omitempty"` -} - -func pcrPtr(pcrs map[uint32][]byte) *map[uint32][]byte { - return &pcrs + PCRs *Measurements `yaml:"pcrs,omitempty"` } // intPtr returns a pointer to the copied value of in. diff --git a/internal/config/measurements.go b/internal/config/measurements.go new file mode 100644 index 000000000..2663f7ef7 --- /dev/null +++ b/internal/config/measurements.go @@ -0,0 +1,33 @@ +package config + +import "encoding/base64" + +type Measurements map[uint32][]byte + +func (m Measurements) MarshalYAML() (interface{}, error) { + base64Map := make(map[uint32]string) + + for key, value := range m { + base64Map[key] = base64.StdEncoding.EncodeToString(value[:]) + } + + return base64Map, nil +} + +func (m *Measurements) UnmarshalYAML(unmarshal func(interface{}) error) error { + base64Map := make(map[uint32]string) + err := unmarshal(base64Map) + if err != nil { + return err + } + + *m = make(Measurements) + for key, value := range base64Map { + measurement, err := base64.StdEncoding.DecodeString(value) + if err != nil { + return err + } + (*m)[key] = measurement + } + return nil +} diff --git a/internal/config/measurements_test.go b/internal/config/measurements_test.go new file mode 100644 index 000000000..71bf8ecde --- /dev/null +++ b/internal/config/measurements_test.go @@ -0,0 +1,129 @@ +package config + +import ( + "errors" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestMarshalYAML(t *testing.T) { + testCases := map[string]struct { + measurements Measurements + wantBase64Map map[uint32]string + }{ + "valid measurements": { + measurements: Measurements{ + 2: []byte{253, 93, 233, 223, 53, 14, 59, 196, 65, 10, 192, 107, 191, 229, 204, 222, 185, 63, 83, 185, 239, 81, 35, 159, 117, 44, 230, 157, 188, 96, 15, 53}, + 3: []byte{213, 164, 73, 109, 33, 222, 201, 165, 37, 141, 219, 25, 198, 254, 181, 59, 180, 211, 192, 70, 63, 230, 7, 242, 72, 141, 223, 79, 16, 6, 239, 158}, + }, + wantBase64Map: map[uint32]string{ + 2: "/V3p3zUOO8RBCsBrv+XM3rk/U7nvUSOfdSzmnbxgDzU=", + 3: "1aRJbSHeyaUljdsZxv61O7TTwEY/5gfySI3fTxAG754=", + }, + }, + "omit bytes": { + measurements: Measurements{ + 2: []byte{}, + 3: []byte{1, 2, 3, 4}, + }, + wantBase64Map: map[uint32]string{ + 2: "", + 3: "AQIDBA==", + }, + }, + } + + for name, tc := range testCases { + t.Run(name, func(t *testing.T) { + assert := assert.New(t) + require := require.New(t) + + base64Map, err := tc.measurements.MarshalYAML() + require.NoError(err) + + assert.Equal(tc.wantBase64Map, base64Map) + }) + } +} + +func TestUnmarshalYAML(t *testing.T) { + testCases := map[string]struct { + inputBase64Map map[uint32]string + forceUnmarshalError bool + wantMeasurements Measurements + wantErr bool + }{ + "valid measurements": { + inputBase64Map: map[uint32]string{ + 2: "/V3p3zUOO8RBCsBrv+XM3rk/U7nvUSOfdSzmnbxgDzU=", + 3: "1aRJbSHeyaUljdsZxv61O7TTwEY/5gfySI3fTxAG754=", + }, + wantMeasurements: Measurements{ + 2: []byte{253, 93, 233, 223, 53, 14, 59, 196, 65, 10, 192, 107, 191, 229, 204, 222, 185, 63, 83, 185, 239, 81, 35, 159, 117, 44, 230, 157, 188, 96, 15, 53}, + 3: []byte{213, 164, 73, 109, 33, 222, 201, 165, 37, 141, 219, 25, 198, 254, 181, 59, 180, 211, 192, 70, 63, 230, 7, 242, 72, 141, 223, 79, 16, 6, 239, 158}, + }, + }, + "empty bytes": { + inputBase64Map: map[uint32]string{ + 2: "AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=", + 3: "AQIDBAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=", + }, + wantMeasurements: Measurements{ + 2: []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, + 3: []byte{1, 2, 3, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, + }, + }, + "invalid base64": { + inputBase64Map: map[uint32]string{ + 2: "This is not base64", + 3: "AQIDBAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=", + }, + wantMeasurements: Measurements{ + 2: []byte{}, + 3: []byte{1, 2, 3, 4}, + }, + wantErr: true, + }, + "simulated unmarshal error": { + inputBase64Map: map[uint32]string{ + 2: "AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=", + 3: "AQIDBAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=", + }, + forceUnmarshalError: true, + wantMeasurements: Measurements{ + 2: []byte{}, + 3: []byte{1, 2, 3, 4}, + }, + wantErr: true, + }, + } + + for name, tc := range testCases { + t.Run(name, func(t *testing.T) { + assert := assert.New(t) + require := require.New(t) + + var m Measurements + err := m.UnmarshalYAML(func(i interface{}) error { + if base64Map, ok := i.(map[uint32]string); ok { + for key, value := range tc.inputBase64Map { + base64Map[key] = value + } + } + if tc.forceUnmarshalError { + return errors.New("unmarshal error") + } + return nil + }) + + if tc.wantErr { + assert.Error(err) + } else { + require.NoError(err) + assert.Equal(tc.wantMeasurements, m) + } + }) + } +}