attestation: print ordered measurement verification warnings and errors (#2237)

Signed-off-by: Daniel Weiße <dw@edgeless.systems>
This commit is contained in:
Daniel Weiße 2023-08-16 10:45:54 +02:00 committed by GitHub
parent 78fa921746
commit 103817a4a5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 184 additions and 25 deletions

View File

@ -226,6 +226,37 @@ func (m *M) EqualTo(other M) bool {
return true
}
// Compare compares the expected measurements to the given list of measurements.
// It returns a list of warnings for non matching measurements for WarnOnly entries,
// and a list of errors for non matching measurements for Enforce entries.
func (m M) Compare(other map[uint32][]byte) (warnings []string, errs []error) {
// Get list of indices in expected measurements
var mIndices []uint32
for idx := range m {
mIndices = append(mIndices, idx)
}
sort.SliceStable(mIndices, func(i, j int) bool {
return mIndices[i] < mIndices[j]
})
for _, idx := range mIndices {
if !bytes.Equal(m[idx].Expected, other[idx]) {
msg := fmt.Sprintf("untrusted measurement value %x at index %d", other[idx], idx)
if len(other[idx]) == 0 {
msg = fmt.Sprintf("missing measurement value for index %d", idx)
}
if m[idx].ValidationOpt == Enforce {
errs = append(errs, errors.New(msg))
} else {
warnings = append(warnings, fmt.Sprintf("Encountered %s", msg))
}
}
}
return warnings, errs
}
// GetEnforced returns a list of all enforced Measurements,
// i.e. all Measurements that are not marked as WarnOnly.
func (m *M) GetEnforced() []uint32 {

View File

@ -7,6 +7,7 @@ SPDX-License-Identifier: AGPL-3.0-only
package measurements
import (
"bytes"
"context"
"encoding/json"
"io"
@ -928,3 +929,117 @@ func TestMergeImageMeasurementsV2(t *testing.T) {
})
}
}
func TestMeasurementsCompare(t *testing.T) {
testCases := map[string]struct {
expected M
actual map[uint32][]byte
wantErrs int
wantWarnings int
}{
"no errors": {
expected: M{
0: WithAllBytes(0x00, Enforce, PCRMeasurementLength),
1: WithAllBytes(0x11, Enforce, PCRMeasurementLength),
},
actual: map[uint32][]byte{
0: bytes.Repeat([]byte{0x00}, PCRMeasurementLength),
1: bytes.Repeat([]byte{0x11}, PCRMeasurementLength),
},
wantErrs: 0,
wantWarnings: 0,
},
"no errors, with warnings": {
expected: M{
0: WithAllBytes(0x00, Enforce, PCRMeasurementLength),
1: WithAllBytes(0x11, WarnOnly, PCRMeasurementLength),
2: WithAllBytes(0x22, WarnOnly, PCRMeasurementLength),
},
actual: map[uint32][]byte{
0: bytes.Repeat([]byte{0x00}, PCRMeasurementLength),
1: bytes.Repeat([]byte{0xFF}, PCRMeasurementLength),
2: bytes.Repeat([]byte{0xFF}, PCRMeasurementLength),
},
wantErrs: 0,
wantWarnings: 2,
},
"with errors, no warnings": {
expected: M{
0: WithAllBytes(0x00, Enforce, PCRMeasurementLength),
1: WithAllBytes(0x11, Enforce, PCRMeasurementLength),
2: WithAllBytes(0x22, Enforce, PCRMeasurementLength),
},
actual: map[uint32][]byte{
0: bytes.Repeat([]byte{0x00}, PCRMeasurementLength),
1: bytes.Repeat([]byte{0xFF}, PCRMeasurementLength),
2: bytes.Repeat([]byte{0xFF}, PCRMeasurementLength),
},
wantErrs: 2,
wantWarnings: 0,
},
"with errors and warnings": {
expected: M{
0: WithAllBytes(0x00, Enforce, PCRMeasurementLength),
1: WithAllBytes(0x11, WarnOnly, PCRMeasurementLength),
2: WithAllBytes(0x22, Enforce, PCRMeasurementLength),
},
actual: map[uint32][]byte{
0: bytes.Repeat([]byte{0x00}, PCRMeasurementLength),
1: bytes.Repeat([]byte{0xFF}, PCRMeasurementLength),
2: bytes.Repeat([]byte{0xFF}, PCRMeasurementLength),
},
wantErrs: 1,
wantWarnings: 1,
},
"extra measurements don't cause errors": {
expected: M{
0: WithAllBytes(0x00, Enforce, PCRMeasurementLength),
1: WithAllBytes(0x11, Enforce, PCRMeasurementLength),
},
actual: map[uint32][]byte{
0: bytes.Repeat([]byte{0x00}, PCRMeasurementLength),
1: bytes.Repeat([]byte{0x11}, PCRMeasurementLength),
2: bytes.Repeat([]byte{0x22}, PCRMeasurementLength),
},
wantErrs: 0,
wantWarnings: 0,
},
"missing measurements cause errors": {
expected: M{
0: WithAllBytes(0x00, Enforce, PCRMeasurementLength),
1: WithAllBytes(0x11, Enforce, PCRMeasurementLength),
2: WithAllBytes(0x22, Enforce, PCRMeasurementLength),
},
actual: map[uint32][]byte{
0: bytes.Repeat([]byte{0x00}, PCRMeasurementLength),
1: bytes.Repeat([]byte{0x11}, PCRMeasurementLength),
},
wantErrs: 1,
wantWarnings: 0,
},
"missing measurements cause warnings": {
expected: M{
0: WithAllBytes(0x00, Enforce, PCRMeasurementLength),
1: WithAllBytes(0x11, Enforce, PCRMeasurementLength),
2: WithAllBytes(0x22, WarnOnly, PCRMeasurementLength),
},
actual: map[uint32][]byte{
0: bytes.Repeat([]byte{0x00}, PCRMeasurementLength),
1: bytes.Repeat([]byte{0x11}, PCRMeasurementLength),
},
wantErrs: 0,
wantWarnings: 1,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
gotWarnings, gotErrs := tc.expected.Compare(tc.actual)
assert.Equal(tc.wantErrs, len(gotErrs))
assert.Equal(tc.wantWarnings, len(gotWarnings))
})
}
}

View File

@ -7,9 +7,9 @@ SPDX-License-Identifier: AGPL-3.0-only
package tdx
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"github.com/edgelesssys/constellation/v2/internal/attestation"
@ -81,13 +81,12 @@ func (v *Validator) Validate(ctx context.Context, attDocRaw []byte, nonce []byte
}
// Verify the quote against the expected measurements.
for idx, ex := range v.expected {
if !bytes.Equal(ex.Expected, tdMeasure[idx]) {
if !ex.ValidationOpt {
return nil, fmt.Errorf("untrusted TD measurement value at index %d", idx)
}
v.log.Warnf("Encountered untrusted TD measurement value at index %d", idx)
}
warnings, errs := v.expected.Compare(tdMeasure)
for _, warning := range warnings {
v.log.Warnf(warning)
}
if len(errs) > 0 {
return nil, fmt.Errorf("measurement validation failed:\n%w", errors.Join(errs...))
}
return attDoc.UserData, nil

View File

@ -7,10 +7,10 @@ SPDX-License-Identifier: AGPL-3.0-only
package vtpm
import (
"bytes"
"context"
"crypto"
"encoding/json"
"errors"
"fmt"
"io"
@ -219,21 +219,12 @@ func (v *Validator) Validate(ctx context.Context, attDocRaw []byte, nonce []byte
if err != nil {
return nil, err
}
for idx, pcr := range v.expected {
if !bytes.Equal(pcr.Expected[:], attDoc.Attestation.Quotes[quoteIdx].Pcrs.Pcrs[idx]) {
if pcr.ValidationOpt == measurements.Enforce {
return nil, fmt.Errorf(
"untrusted PCR value %x at index %d",
attDoc.Attestation.Quotes[quoteIdx].Pcrs.Pcrs[idx],
idx,
)
}
v.log.Warnf(
"Encountered untrusted PCR value %x at index %d",
attDoc.Attestation.Quotes[quoteIdx].Pcrs.Pcrs[idx],
idx,
)
}
warnings, errs := v.expected.Compare(attDoc.Attestation.Quotes[quoteIdx].Pcrs.Pcrs)
for _, warning := range warnings {
v.log.Warnf(warning)
}
if len(errs) > 0 {
return nil, fmt.Errorf("measurement validation failed:\n%w", errors.Join(errs...))
}
v.log.Infof("Successfully validated attestation document")

View File

@ -80,7 +80,7 @@ func TestValidate(t *testing.T) {
defer tpmCloser.Close()
issuer := NewIssuer(tpmOpen, tpmclient.AttestationKeyRSA, fakeGetInstanceInfo, logger.NewTest(t))
validator := NewValidator(testExpectedPCRs, fakeGetTrustedKey, fakeValidateCVM, nil)
validator := NewValidator(testExpectedPCRs, fakeGetTrustedKey, fakeValidateCVM, logger.NewTest(t))
nonce := []byte{1, 2, 3, 4}
challenge := []byte("Constellation")
@ -206,6 +206,10 @@ func TestValidate(t *testing.T) {
Expected: []byte{0xFF},
ValidationOpt: measurements.Enforce,
},
1: measurements.Measurement{
Expected: []byte{0xFF},
ValidationOpt: measurements.Enforce,
},
},
fakeGetTrustedKey,
fakeValidateCVM,
@ -214,6 +218,25 @@ func TestValidate(t *testing.T) {
nonce: nonce,
wantErr: true,
},
"untrusted WarnOnly PCRs": {
validator: NewValidator(
measurements.M{
0: measurements.Measurement{
Expected: []byte{0xFF},
ValidationOpt: measurements.WarnOnly,
},
1: measurements.Measurement{
Expected: []byte{0xFF},
ValidationOpt: measurements.WarnOnly,
},
},
fakeGetTrustedKey,
fakeValidateCVM,
logger.NewTest(t)),
attDoc: mustMarshalAttestation(attDoc, require),
nonce: nonce,
wantErr: false,
},
"no sha256 quote": {
validator: NewValidator(testExpectedPCRs, fakeGetTrustedKey, fakeValidateCVM, warnLog),
attDoc: mustMarshalAttestation(AttestationDocument{