From 103817a4a566702c30c29a53361437bf45faeb3c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Daniel=20Wei=C3=9Fe?= <66256922+daniel-weisse@users.noreply.github.com> Date: Wed, 16 Aug 2023 10:45:54 +0200 Subject: [PATCH] attestation: print ordered measurement verification warnings and errors (#2237) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Daniel Weiße --- .../attestation/measurements/measurements.go | 31 +++++ .../measurements/measurements_test.go | 115 ++++++++++++++++++ internal/attestation/tdx/validator.go | 15 ++- internal/attestation/vtpm/attestation.go | 23 ++-- internal/attestation/vtpm/attestation_test.go | 25 +++- 5 files changed, 184 insertions(+), 25 deletions(-) diff --git a/internal/attestation/measurements/measurements.go b/internal/attestation/measurements/measurements.go index edf885cbf..6e775d06b 100644 --- a/internal/attestation/measurements/measurements.go +++ b/internal/attestation/measurements/measurements.go @@ -226,6 +226,37 @@ func (m *M) EqualTo(other M) bool { return true } +// Compare compares the expected measurements to the given list of measurements. +// It returns a list of warnings for non matching measurements for WarnOnly entries, +// and a list of errors for non matching measurements for Enforce entries. +func (m M) Compare(other map[uint32][]byte) (warnings []string, errs []error) { + // Get list of indices in expected measurements + var mIndices []uint32 + for idx := range m { + mIndices = append(mIndices, idx) + } + sort.SliceStable(mIndices, func(i, j int) bool { + return mIndices[i] < mIndices[j] + }) + + for _, idx := range mIndices { + if !bytes.Equal(m[idx].Expected, other[idx]) { + msg := fmt.Sprintf("untrusted measurement value %x at index %d", other[idx], idx) + if len(other[idx]) == 0 { + msg = fmt.Sprintf("missing measurement value for index %d", idx) + } + + if m[idx].ValidationOpt == Enforce { + errs = append(errs, errors.New(msg)) + } else { + warnings = append(warnings, fmt.Sprintf("Encountered %s", msg)) + } + } + } + + return warnings, errs +} + // GetEnforced returns a list of all enforced Measurements, // i.e. all Measurements that are not marked as WarnOnly. func (m *M) GetEnforced() []uint32 { diff --git a/internal/attestation/measurements/measurements_test.go b/internal/attestation/measurements/measurements_test.go index d08c95030..73cee7479 100644 --- a/internal/attestation/measurements/measurements_test.go +++ b/internal/attestation/measurements/measurements_test.go @@ -7,6 +7,7 @@ SPDX-License-Identifier: AGPL-3.0-only package measurements import ( + "bytes" "context" "encoding/json" "io" @@ -928,3 +929,117 @@ func TestMergeImageMeasurementsV2(t *testing.T) { }) } } + +func TestMeasurementsCompare(t *testing.T) { + testCases := map[string]struct { + expected M + actual map[uint32][]byte + wantErrs int + wantWarnings int + }{ + "no errors": { + expected: M{ + 0: WithAllBytes(0x00, Enforce, PCRMeasurementLength), + 1: WithAllBytes(0x11, Enforce, PCRMeasurementLength), + }, + actual: map[uint32][]byte{ + 0: bytes.Repeat([]byte{0x00}, PCRMeasurementLength), + 1: bytes.Repeat([]byte{0x11}, PCRMeasurementLength), + }, + wantErrs: 0, + wantWarnings: 0, + }, + "no errors, with warnings": { + expected: M{ + 0: WithAllBytes(0x00, Enforce, PCRMeasurementLength), + 1: WithAllBytes(0x11, WarnOnly, PCRMeasurementLength), + 2: WithAllBytes(0x22, WarnOnly, PCRMeasurementLength), + }, + actual: map[uint32][]byte{ + 0: bytes.Repeat([]byte{0x00}, PCRMeasurementLength), + 1: bytes.Repeat([]byte{0xFF}, PCRMeasurementLength), + 2: bytes.Repeat([]byte{0xFF}, PCRMeasurementLength), + }, + wantErrs: 0, + wantWarnings: 2, + }, + "with errors, no warnings": { + expected: M{ + 0: WithAllBytes(0x00, Enforce, PCRMeasurementLength), + 1: WithAllBytes(0x11, Enforce, PCRMeasurementLength), + 2: WithAllBytes(0x22, Enforce, PCRMeasurementLength), + }, + actual: map[uint32][]byte{ + 0: bytes.Repeat([]byte{0x00}, PCRMeasurementLength), + 1: bytes.Repeat([]byte{0xFF}, PCRMeasurementLength), + 2: bytes.Repeat([]byte{0xFF}, PCRMeasurementLength), + }, + wantErrs: 2, + wantWarnings: 0, + }, + "with errors and warnings": { + expected: M{ + 0: WithAllBytes(0x00, Enforce, PCRMeasurementLength), + 1: WithAllBytes(0x11, WarnOnly, PCRMeasurementLength), + 2: WithAllBytes(0x22, Enforce, PCRMeasurementLength), + }, + + actual: map[uint32][]byte{ + 0: bytes.Repeat([]byte{0x00}, PCRMeasurementLength), + 1: bytes.Repeat([]byte{0xFF}, PCRMeasurementLength), + 2: bytes.Repeat([]byte{0xFF}, PCRMeasurementLength), + }, + wantErrs: 1, + wantWarnings: 1, + }, + "extra measurements don't cause errors": { + expected: M{ + 0: WithAllBytes(0x00, Enforce, PCRMeasurementLength), + 1: WithAllBytes(0x11, Enforce, PCRMeasurementLength), + }, + actual: map[uint32][]byte{ + 0: bytes.Repeat([]byte{0x00}, PCRMeasurementLength), + 1: bytes.Repeat([]byte{0x11}, PCRMeasurementLength), + 2: bytes.Repeat([]byte{0x22}, PCRMeasurementLength), + }, + wantErrs: 0, + wantWarnings: 0, + }, + "missing measurements cause errors": { + expected: M{ + 0: WithAllBytes(0x00, Enforce, PCRMeasurementLength), + 1: WithAllBytes(0x11, Enforce, PCRMeasurementLength), + 2: WithAllBytes(0x22, Enforce, PCRMeasurementLength), + }, + actual: map[uint32][]byte{ + 0: bytes.Repeat([]byte{0x00}, PCRMeasurementLength), + 1: bytes.Repeat([]byte{0x11}, PCRMeasurementLength), + }, + wantErrs: 1, + wantWarnings: 0, + }, + "missing measurements cause warnings": { + expected: M{ + 0: WithAllBytes(0x00, Enforce, PCRMeasurementLength), + 1: WithAllBytes(0x11, Enforce, PCRMeasurementLength), + 2: WithAllBytes(0x22, WarnOnly, PCRMeasurementLength), + }, + actual: map[uint32][]byte{ + 0: bytes.Repeat([]byte{0x00}, PCRMeasurementLength), + 1: bytes.Repeat([]byte{0x11}, PCRMeasurementLength), + }, + wantErrs: 0, + wantWarnings: 1, + }, + } + + for name, tc := range testCases { + t.Run(name, func(t *testing.T) { + assert := assert.New(t) + + gotWarnings, gotErrs := tc.expected.Compare(tc.actual) + assert.Equal(tc.wantErrs, len(gotErrs)) + assert.Equal(tc.wantWarnings, len(gotWarnings)) + }) + } +} diff --git a/internal/attestation/tdx/validator.go b/internal/attestation/tdx/validator.go index ea1172170..0a18f1c9d 100644 --- a/internal/attestation/tdx/validator.go +++ b/internal/attestation/tdx/validator.go @@ -7,9 +7,9 @@ SPDX-License-Identifier: AGPL-3.0-only package tdx import ( - "bytes" "context" "encoding/json" + "errors" "fmt" "github.com/edgelesssys/constellation/v2/internal/attestation" @@ -81,13 +81,12 @@ func (v *Validator) Validate(ctx context.Context, attDocRaw []byte, nonce []byte } // Verify the quote against the expected measurements. - for idx, ex := range v.expected { - if !bytes.Equal(ex.Expected, tdMeasure[idx]) { - if !ex.ValidationOpt { - return nil, fmt.Errorf("untrusted TD measurement value at index %d", idx) - } - v.log.Warnf("Encountered untrusted TD measurement value at index %d", idx) - } + warnings, errs := v.expected.Compare(tdMeasure) + for _, warning := range warnings { + v.log.Warnf(warning) + } + if len(errs) > 0 { + return nil, fmt.Errorf("measurement validation failed:\n%w", errors.Join(errs...)) } return attDoc.UserData, nil diff --git a/internal/attestation/vtpm/attestation.go b/internal/attestation/vtpm/attestation.go index e689e9c43..0d907dd7f 100644 --- a/internal/attestation/vtpm/attestation.go +++ b/internal/attestation/vtpm/attestation.go @@ -7,10 +7,10 @@ SPDX-License-Identifier: AGPL-3.0-only package vtpm import ( - "bytes" "context" "crypto" "encoding/json" + "errors" "fmt" "io" @@ -219,21 +219,12 @@ func (v *Validator) Validate(ctx context.Context, attDocRaw []byte, nonce []byte if err != nil { return nil, err } - for idx, pcr := range v.expected { - if !bytes.Equal(pcr.Expected[:], attDoc.Attestation.Quotes[quoteIdx].Pcrs.Pcrs[idx]) { - if pcr.ValidationOpt == measurements.Enforce { - return nil, fmt.Errorf( - "untrusted PCR value %x at index %d", - attDoc.Attestation.Quotes[quoteIdx].Pcrs.Pcrs[idx], - idx, - ) - } - v.log.Warnf( - "Encountered untrusted PCR value %x at index %d", - attDoc.Attestation.Quotes[quoteIdx].Pcrs.Pcrs[idx], - idx, - ) - } + warnings, errs := v.expected.Compare(attDoc.Attestation.Quotes[quoteIdx].Pcrs.Pcrs) + for _, warning := range warnings { + v.log.Warnf(warning) + } + if len(errs) > 0 { + return nil, fmt.Errorf("measurement validation failed:\n%w", errors.Join(errs...)) } v.log.Infof("Successfully validated attestation document") diff --git a/internal/attestation/vtpm/attestation_test.go b/internal/attestation/vtpm/attestation_test.go index 563404a2e..15128aee3 100644 --- a/internal/attestation/vtpm/attestation_test.go +++ b/internal/attestation/vtpm/attestation_test.go @@ -80,7 +80,7 @@ func TestValidate(t *testing.T) { defer tpmCloser.Close() issuer := NewIssuer(tpmOpen, tpmclient.AttestationKeyRSA, fakeGetInstanceInfo, logger.NewTest(t)) - validator := NewValidator(testExpectedPCRs, fakeGetTrustedKey, fakeValidateCVM, nil) + validator := NewValidator(testExpectedPCRs, fakeGetTrustedKey, fakeValidateCVM, logger.NewTest(t)) nonce := []byte{1, 2, 3, 4} challenge := []byte("Constellation") @@ -206,6 +206,10 @@ func TestValidate(t *testing.T) { Expected: []byte{0xFF}, ValidationOpt: measurements.Enforce, }, + 1: measurements.Measurement{ + Expected: []byte{0xFF}, + ValidationOpt: measurements.Enforce, + }, }, fakeGetTrustedKey, fakeValidateCVM, @@ -214,6 +218,25 @@ func TestValidate(t *testing.T) { nonce: nonce, wantErr: true, }, + "untrusted WarnOnly PCRs": { + validator: NewValidator( + measurements.M{ + 0: measurements.Measurement{ + Expected: []byte{0xFF}, + ValidationOpt: measurements.WarnOnly, + }, + 1: measurements.Measurement{ + Expected: []byte{0xFF}, + ValidationOpt: measurements.WarnOnly, + }, + }, + fakeGetTrustedKey, + fakeValidateCVM, + logger.NewTest(t)), + attDoc: mustMarshalAttestation(attDoc, require), + nonce: nonce, + wantErr: false, + }, "no sha256 quote": { validator: NewValidator(testExpectedPCRs, fakeGetTrustedKey, fakeValidateCVM, warnLog), attDoc: mustMarshalAttestation(AttestationDocument{