measurements: refactor validation option (#1462)

Signed-off-by: Paul Meyer <49727155+katexochen@users.noreply.github.com>
This commit is contained in:
Paul Meyer 2023-03-22 06:47:39 -04:00 committed by GitHub
parent 1ab40b7ca6
commit 02fc3dc635
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
17 changed files with 290 additions and 281 deletions

View file

@ -272,7 +272,9 @@ func (u *Upgrader) updateMeasurements(ctx context.Context, newMeasurements measu
// don't allow potential security downgrades by setting the warnOnly flag to true
for k, newM := range newMeasurements {
if currentM, ok := currentMeasurements[k]; ok && !currentM.WarnOnly && newM.WarnOnly {
if currentM, ok := currentMeasurements[k]; ok &&
currentM.ValidationOpt != measurements.WarnOnly &&
newM.ValidationOpt == measurements.WarnOnly {
return fmt.Errorf("setting enforced measurement %d to warn only: not allowed", k)
}
}

View file

@ -218,7 +218,7 @@ func TestUpdateMeasurements(t *testing.T) {
},
},
newMeasurements: measurements.M{
0: measurements.WithAllBytes(0xBB, false),
0: measurements.WithAllBytes(0xBB, measurements.Enforce),
},
wantUpdate: true,
},
@ -231,7 +231,7 @@ func TestUpdateMeasurements(t *testing.T) {
},
},
newMeasurements: measurements.M{
0: measurements.WithAllBytes(0xAA, false),
0: measurements.WithAllBytes(0xAA, measurements.Enforce),
},
},
"trying to set warnOnly to true results in error": {
@ -243,7 +243,7 @@ func TestUpdateMeasurements(t *testing.T) {
},
},
newMeasurements: measurements.M{
0: measurements.WithAllBytes(0xAA, true),
0: measurements.WithAllBytes(0xAA, measurements.WarnOnly),
},
wantErr: true,
},
@ -256,7 +256,7 @@ func TestUpdateMeasurements(t *testing.T) {
},
},
newMeasurements: measurements.M{
0: measurements.WithAllBytes(0xAA, false),
0: measurements.WithAllBytes(0xAA, measurements.Enforce),
},
wantUpdate: true,
},

View file

@ -89,8 +89,8 @@ func (v *Validator) updatePCR(pcrIndex uint32, encoded string) error {
oldExpected := v.pcrs[pcrIndex].Expected
expectedPcr := sha256.Sum256(append(oldExpected[:], hashedInput[:]...))
v.pcrs[pcrIndex] = measurements.Measurement{
Expected: expectedPcr,
WarnOnly: v.pcrs[pcrIndex].WarnOnly,
Expected: expectedPcr,
ValidationOpt: v.pcrs[pcrIndex].ValidationOpt,
}
return nil
}

View file

@ -29,12 +29,12 @@ import (
func TestNewValidator(t *testing.T) {
testPCRs := measurements.M{
0: measurements.WithAllBytes(0x00, false),
1: measurements.WithAllBytes(0xFF, false),
2: measurements.WithAllBytes(0x00, false),
3: measurements.WithAllBytes(0xFF, false),
4: measurements.WithAllBytes(0x00, false),
5: measurements.WithAllBytes(0x00, false),
0: measurements.WithAllBytes(0x00, measurements.Enforce),
1: measurements.WithAllBytes(0xFF, measurements.Enforce),
2: measurements.WithAllBytes(0x00, measurements.Enforce),
3: measurements.WithAllBytes(0xFF, measurements.Enforce),
4: measurements.WithAllBytes(0x00, measurements.Enforce),
5: measurements.WithAllBytes(0x00, measurements.Enforce),
}
testCases := map[string]struct {
@ -139,19 +139,19 @@ func TestNewValidator(t *testing.T) {
func TestValidatorV(t *testing.T) {
newTestPCRs := func() measurements.M {
return measurements.M{
0: measurements.WithAllBytes(0x00, true),
1: measurements.WithAllBytes(0x00, true),
2: measurements.WithAllBytes(0x00, true),
3: measurements.WithAllBytes(0x00, true),
4: measurements.WithAllBytes(0x00, true),
5: measurements.WithAllBytes(0x00, true),
6: measurements.WithAllBytes(0x00, true),
7: measurements.WithAllBytes(0x00, true),
8: measurements.WithAllBytes(0x00, true),
9: measurements.WithAllBytes(0x00, true),
10: measurements.WithAllBytes(0x00, true),
11: measurements.WithAllBytes(0x00, true),
12: measurements.WithAllBytes(0x00, true),
0: measurements.WithAllBytes(0x00, measurements.WarnOnly),
1: measurements.WithAllBytes(0x00, measurements.WarnOnly),
2: measurements.WithAllBytes(0x00, measurements.WarnOnly),
3: measurements.WithAllBytes(0x00, measurements.WarnOnly),
4: measurements.WithAllBytes(0x00, measurements.WarnOnly),
5: measurements.WithAllBytes(0x00, measurements.WarnOnly),
6: measurements.WithAllBytes(0x00, measurements.WarnOnly),
7: measurements.WithAllBytes(0x00, measurements.WarnOnly),
8: measurements.WithAllBytes(0x00, measurements.WarnOnly),
9: measurements.WithAllBytes(0x00, measurements.WarnOnly),
10: measurements.WithAllBytes(0x00, measurements.WarnOnly),
11: measurements.WithAllBytes(0x00, measurements.WarnOnly),
12: measurements.WithAllBytes(0x00, measurements.WarnOnly),
}
}
@ -200,37 +200,37 @@ func TestValidatorV(t *testing.T) {
}
func TestValidatorUpdateInitPCRs(t *testing.T) {
zero := measurements.WithAllBytes(0x00, true)
one := measurements.WithAllBytes(0x11, true)
zero := measurements.WithAllBytes(0x00, measurements.WarnOnly)
one := measurements.WithAllBytes(0x11, measurements.WarnOnly)
one64 := base64.StdEncoding.EncodeToString(one.Expected[:])
oneHash := sha256.Sum256(one.Expected[:])
pcrZeroUpdatedOne := sha256.Sum256(append(zero.Expected[:], oneHash[:]...))
newTestPCRs := func() measurements.M {
return measurements.M{
0: measurements.WithAllBytes(0x00, true),
1: measurements.WithAllBytes(0x00, true),
2: measurements.WithAllBytes(0x00, true),
3: measurements.WithAllBytes(0x00, true),
4: measurements.WithAllBytes(0x00, true),
5: measurements.WithAllBytes(0x00, true),
6: measurements.WithAllBytes(0x00, true),
7: measurements.WithAllBytes(0x00, true),
8: measurements.WithAllBytes(0x00, true),
9: measurements.WithAllBytes(0x00, true),
10: measurements.WithAllBytes(0x00, true),
11: measurements.WithAllBytes(0x00, true),
12: measurements.WithAllBytes(0x00, true),
13: measurements.WithAllBytes(0x00, true),
14: measurements.WithAllBytes(0x00, true),
15: measurements.WithAllBytes(0x00, true),
16: measurements.WithAllBytes(0x00, true),
17: measurements.WithAllBytes(0x11, true),
18: measurements.WithAllBytes(0x11, true),
19: measurements.WithAllBytes(0x11, true),
20: measurements.WithAllBytes(0x11, true),
21: measurements.WithAllBytes(0x11, true),
22: measurements.WithAllBytes(0x11, true),
23: measurements.WithAllBytes(0x00, true),
0: measurements.WithAllBytes(0x00, measurements.WarnOnly),
1: measurements.WithAllBytes(0x00, measurements.WarnOnly),
2: measurements.WithAllBytes(0x00, measurements.WarnOnly),
3: measurements.WithAllBytes(0x00, measurements.WarnOnly),
4: measurements.WithAllBytes(0x00, measurements.WarnOnly),
5: measurements.WithAllBytes(0x00, measurements.WarnOnly),
6: measurements.WithAllBytes(0x00, measurements.WarnOnly),
7: measurements.WithAllBytes(0x00, measurements.WarnOnly),
8: measurements.WithAllBytes(0x00, measurements.WarnOnly),
9: measurements.WithAllBytes(0x00, measurements.WarnOnly),
10: measurements.WithAllBytes(0x00, measurements.WarnOnly),
11: measurements.WithAllBytes(0x00, measurements.WarnOnly),
12: measurements.WithAllBytes(0x00, measurements.WarnOnly),
13: measurements.WithAllBytes(0x00, measurements.WarnOnly),
14: measurements.WithAllBytes(0x00, measurements.WarnOnly),
15: measurements.WithAllBytes(0x00, measurements.WarnOnly),
16: measurements.WithAllBytes(0x00, measurements.WarnOnly),
17: measurements.WithAllBytes(0x11, measurements.WarnOnly),
18: measurements.WithAllBytes(0x11, measurements.WarnOnly),
19: measurements.WithAllBytes(0x11, measurements.WarnOnly),
20: measurements.WithAllBytes(0x11, measurements.WarnOnly),
21: measurements.WithAllBytes(0x11, measurements.WarnOnly),
22: measurements.WithAllBytes(0x11, measurements.WarnOnly),
23: measurements.WithAllBytes(0x00, measurements.WarnOnly),
}
}
@ -335,8 +335,8 @@ func TestValidatorUpdateInitPCRs(t *testing.T) {
func TestUpdatePCR(t *testing.T) {
emptyMap := measurements.M{}
defaultMap := measurements.M{
0: measurements.WithAllBytes(0xAA, false),
1: measurements.WithAllBytes(0xBB, false),
0: measurements.WithAllBytes(0xAA, measurements.Enforce),
1: measurements.WithAllBytes(0xBB, measurements.Enforce),
}
testCases := map[string]struct {

View file

@ -438,13 +438,13 @@ func TestAttestation(t *testing.T) {
cfg.Image = "image"
cfg.AttestationVariant = oid.QEMUVTPM{}.String()
cfg.RemoveProviderExcept(cloudprovider.QEMU)
cfg.Provider.QEMU.Measurements[0] = measurements.WithAllBytes(0x00, false)
cfg.Provider.QEMU.Measurements[1] = measurements.WithAllBytes(0x11, false)
cfg.Provider.QEMU.Measurements[2] = measurements.WithAllBytes(0x22, false)
cfg.Provider.QEMU.Measurements[3] = measurements.WithAllBytes(0x33, false)
cfg.Provider.QEMU.Measurements[4] = measurements.WithAllBytes(0x44, false)
cfg.Provider.QEMU.Measurements[9] = measurements.WithAllBytes(0x99, false)
cfg.Provider.QEMU.Measurements[12] = measurements.WithAllBytes(0xcc, false)
cfg.Provider.QEMU.Measurements[0] = measurements.WithAllBytes(0x00, measurements.Enforce)
cfg.Provider.QEMU.Measurements[1] = measurements.WithAllBytes(0x11, measurements.Enforce)
cfg.Provider.QEMU.Measurements[2] = measurements.WithAllBytes(0x22, measurements.Enforce)
cfg.Provider.QEMU.Measurements[3] = measurements.WithAllBytes(0x33, measurements.Enforce)
cfg.Provider.QEMU.Measurements[4] = measurements.WithAllBytes(0x44, measurements.Enforce)
cfg.Provider.QEMU.Measurements[9] = measurements.WithAllBytes(0x99, measurements.Enforce)
cfg.Provider.QEMU.Measurements[12] = measurements.WithAllBytes(0xcc, measurements.Enforce)
require.NoError(fileHandler.WriteYAML(constants.ConfigFilename, cfg, file.OptNone))
ctx := context.Background()
@ -538,23 +538,23 @@ func defaultConfigWithExpectedMeasurements(t *testing.T, conf *config.Config, cs
conf.Provider.Azure.ResourceGroup = "test-resource-group"
conf.Provider.Azure.AppClientID = "01234567-0123-0123-0123-0123456789ab"
conf.Provider.Azure.ClientSecretValue = "test-client-secret"
conf.Provider.Azure.Measurements[4] = measurements.WithAllBytes(0x44, false)
conf.Provider.Azure.Measurements[9] = measurements.WithAllBytes(0x11, false)
conf.Provider.Azure.Measurements[12] = measurements.WithAllBytes(0xcc, false)
conf.Provider.Azure.Measurements[4] = measurements.WithAllBytes(0x44, measurements.Enforce)
conf.Provider.Azure.Measurements[9] = measurements.WithAllBytes(0x11, measurements.Enforce)
conf.Provider.Azure.Measurements[12] = measurements.WithAllBytes(0xcc, measurements.Enforce)
case cloudprovider.GCP:
conf.AttestationVariant = oid.GCPSEVES{}.String()
conf.Provider.GCP.Region = "test-region"
conf.Provider.GCP.Project = "test-project"
conf.Provider.GCP.Zone = "test-zone"
conf.Provider.GCP.ServiceAccountKeyPath = "test-key-path"
conf.Provider.GCP.Measurements[4] = measurements.WithAllBytes(0x44, false)
conf.Provider.GCP.Measurements[9] = measurements.WithAllBytes(0x11, false)
conf.Provider.GCP.Measurements[12] = measurements.WithAllBytes(0xcc, false)
conf.Provider.GCP.Measurements[4] = measurements.WithAllBytes(0x44, measurements.Enforce)
conf.Provider.GCP.Measurements[9] = measurements.WithAllBytes(0x11, measurements.Enforce)
conf.Provider.GCP.Measurements[12] = measurements.WithAllBytes(0xcc, measurements.Enforce)
case cloudprovider.QEMU:
conf.AttestationVariant = oid.QEMUVTPM{}.String()
conf.Provider.QEMU.Measurements[4] = measurements.WithAllBytes(0x44, false)
conf.Provider.QEMU.Measurements[9] = measurements.WithAllBytes(0x11, false)
conf.Provider.QEMU.Measurements[12] = measurements.WithAllBytes(0xcc, false)
conf.Provider.QEMU.Measurements[4] = measurements.WithAllBytes(0x44, measurements.Enforce)
conf.Provider.QEMU.Measurements[9] = measurements.WithAllBytes(0x11, measurements.Enforce)
conf.Provider.QEMU.Measurements[12] = measurements.WithAllBytes(0xcc, measurements.Enforce)
}
conf.RemoveProviderExcept(csp)

View file

@ -396,7 +396,7 @@ func prepareGCPValues(values map[string]any) error {
}
m := measurements.M{
1: measurements.WithAllBytes(0xAA, false),
1: measurements.WithAllBytes(0xAA, measurements.Enforce),
}
mJSON, err := json.Marshal(m)
if err != nil {
@ -471,7 +471,7 @@ func prepareOpenStackValues(values map[string]any) error {
if !ok {
return errors.New("missing 'join-service' key")
}
m := measurements.M{1: measurements.WithAllBytes(0xAA, false)}
m := measurements.M{1: measurements.WithAllBytes(0xAA, measurements.Enforce)}
mJSON, err := json.Marshal(m)
if err != nil {
return err
@ -506,7 +506,7 @@ func prepareQEMUValues(values map[string]any) error {
if !ok {
return errors.New("missing 'join-service' key")
}
m := measurements.M{1: measurements.WithAllBytes(0xAA, false)}
m := measurements.M{1: measurements.WithAllBytes(0xAA, measurements.Enforce)}
mJSON, err := json.Marshal(m)
if err != nil {
return err