measurement-reader: unify TPM & TDX sorting

This commit is contained in:
Nils Hanke 2023-03-09 18:55:05 +01:00 committed by Malte Poll
parent 253d201ff3
commit d58b5f1c06
5 changed files with 114 additions and 150 deletions

View file

@ -4,11 +4,67 @@ Copyright (c) Edgeless Systems GmbH
SPDX-License-Identifier: AGPL-3.0-only
*/
// Type definition for sorted measurements.
// Package sorted defines a type for print-friendly sorted measurements and allows sorting TPM and TDX measurements.
package sorted
import (
"fmt"
"sort"
"github.com/edgelesssys/constellation/v2/internal/attestation/measurements"
)
// Measurement wraps a measurement custom index and value.
type Measurement struct {
Index string
Value []byte
}
// MeasurementType are the supported attestation types we can sort.
type MeasurementType uint32
const (
TPM MeasurementType = iota
TDX
)
// SortMeasurements returns the sorted measurements for either TPM or TDX measurements.
func SortMeasurements(m measurements.M, measurementType MeasurementType) []Measurement {
if measurementType != TPM && measurementType != TDX {
return nil
}
keys := make([]uint32, 0, len(m))
for idx := range m {
keys = append(keys, idx)
}
sort.Slice(keys, func(i, j int) bool {
return keys[i] < keys[j]
})
var sortedMeasurements []Measurement
for _, idx := range keys {
var index string
switch measurementType {
case TPM:
index = fmt.Sprintf("PCR[%02d]", idx)
case TDX:
// idx 0 is MRTD
if idx == 0 {
index = "MRTD"
break
}
// RTMR 0 starts at idx 1, so we have to subtract by one here.
index = fmt.Sprintf("RTMR[%01d]", idx-1)
}
expected := m[idx].Expected
sortedMeasurements = append(sortedMeasurements, Measurement{
Index: index,
Value: expected[:],
})
}
return sortedMeasurements
}

View file

@ -0,0 +1,124 @@
/*
Copyright (c) Edgeless Systems GmbH
SPDX-License-Identifier: AGPL-3.0-only
*/
package sorted
import (
"bytes"
"testing"
"github.com/edgelesssys/constellation/v2/internal/attestation/measurements"
"github.com/stretchr/testify/assert"
)
func TestSortMeasurements(t *testing.T) {
testCases := map[string]struct {
measurementType MeasurementType
input measurements.M
want []Measurement
}{
"pre sorted TPM": {
measurementType: TPM,
input: measurements.M{
0: measurements.WithAllBytes(0x11, measurements.Enforce),
1: measurements.WithAllBytes(0x22, measurements.Enforce),
2: measurements.WithAllBytes(0x33, measurements.Enforce),
},
want: []Measurement{
{
Index: "PCR[00]",
Value: bytes.Repeat([]byte{0x11}, 32),
},
{
Index: "PCR[01]",
Value: bytes.Repeat([]byte{0x22}, 32),
},
{
Index: "PCR[02]",
Value: bytes.Repeat([]byte{0x33}, 32),
},
},
},
"unsorted TPM": {
measurementType: TPM,
input: measurements.M{
1: measurements.WithAllBytes(0x22, measurements.Enforce),
0: measurements.WithAllBytes(0x11, measurements.Enforce),
2: measurements.WithAllBytes(0x33, measurements.Enforce),
},
want: []Measurement{
{
Index: "PCR[00]",
Value: bytes.Repeat([]byte{0x11}, 32),
},
{
Index: "PCR[01]",
Value: bytes.Repeat([]byte{0x22}, 32),
},
{
Index: "PCR[02]",
Value: bytes.Repeat([]byte{0x33}, 32),
},
},
},
"pre sorted TDX": {
measurementType: TDX,
input: measurements.M{
0: measurements.WithAllBytes(0x11, false),
1: measurements.WithAllBytes(0x22, false),
2: measurements.WithAllBytes(0x33, false),
},
want: []Measurement{
{
Index: "MRTD",
Value: bytes.Repeat([]byte{0x11}, 32),
},
{
Index: "RTMR[0]",
Value: bytes.Repeat([]byte{0x22}, 32),
},
{
Index: "RTMR[1]",
Value: bytes.Repeat([]byte{0x33}, 32),
},
},
},
"unsorted TDX": {
measurementType: TDX,
input: measurements.M{
1: measurements.WithAllBytes(0x22, false),
0: measurements.WithAllBytes(0x11, false),
2: measurements.WithAllBytes(0x33, false),
},
want: []Measurement{
{
Index: "MRTD",
Value: bytes.Repeat([]byte{0x11}, 32),
},
{
Index: "RTMR[0]",
Value: bytes.Repeat([]byte{0x22}, 32),
},
{
Index: "RTMR[1]",
Value: bytes.Repeat([]byte{0x33}, 32),
},
},
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
got := SortMeasurements(tc.input, tc.measurementType)
for i := range got {
assert.Equal(got[i].Index, tc.want[i].Index)
assert.Equal(got[i].Value, tc.want[i].Value)
}
})
}
}