measurement-reader: unify TPM & TDX sorting

This commit is contained in:
Nils Hanke 2023-03-09 18:55:05 +01:00 committed by Malte Poll
parent 253d201ff3
commit d58b5f1c06
5 changed files with 114 additions and 150 deletions

View File

@ -4,11 +4,67 @@ Copyright (c) Edgeless Systems GmbH
SPDX-License-Identifier: AGPL-3.0-only SPDX-License-Identifier: AGPL-3.0-only
*/ */
// Type definition for sorted measurements. // Package sorted defines a type for print-friendly sorted measurements and allows sorting TPM and TDX measurements.
package sorted package sorted
import (
"fmt"
"sort"
"github.com/edgelesssys/constellation/v2/internal/attestation/measurements"
)
// Measurement wraps a measurement custom index and value. // Measurement wraps a measurement custom index and value.
type Measurement struct { type Measurement struct {
Index string Index string
Value []byte Value []byte
} }
// MeasurementType are the supported attestation types we can sort.
type MeasurementType uint32
const (
TPM MeasurementType = iota
TDX
)
// SortMeasurements returns the sorted measurements for either TPM or TDX measurements.
func SortMeasurements(m measurements.M, measurementType MeasurementType) []Measurement {
if measurementType != TPM && measurementType != TDX {
return nil
}
keys := make([]uint32, 0, len(m))
for idx := range m {
keys = append(keys, idx)
}
sort.Slice(keys, func(i, j int) bool {
return keys[i] < keys[j]
})
var sortedMeasurements []Measurement
for _, idx := range keys {
var index string
switch measurementType {
case TPM:
index = fmt.Sprintf("PCR[%02d]", idx)
case TDX:
// idx 0 is MRTD
if idx == 0 {
index = "MRTD"
break
}
// RTMR 0 starts at idx 1, so we have to subtract by one here.
index = fmt.Sprintf("RTMR[%01d]", idx-1)
}
expected := m[idx].Expected
sortedMeasurements = append(sortedMeasurements, Measurement{
Index: index,
Value: expected[:],
})
}
return sortedMeasurements
}

View File

@ -4,29 +4,30 @@ Copyright (c) Edgeless Systems GmbH
SPDX-License-Identifier: AGPL-3.0-only SPDX-License-Identifier: AGPL-3.0-only
*/ */
package tpm package sorted
import ( import (
"bytes" "bytes"
"testing" "testing"
"github.com/edgelesssys/constellation/v2/internal/attestation/measurements" "github.com/edgelesssys/constellation/v2/internal/attestation/measurements"
"github.com/edgelesssys/constellation/v2/measurement-reader/internal/sorted"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
func TestSortMeasurements(t *testing.T) { func TestSortMeasurements(t *testing.T) {
testCases := map[string]struct { testCases := map[string]struct {
input measurements.M measurementType MeasurementType
want []sorted.Measurement input measurements.M
want []Measurement
}{ }{
"pre sorted": { "pre sorted TPM": {
measurementType: TPM,
input: measurements.M{ input: measurements.M{
0: measurements.WithAllBytes(0x11, measurements.Enforce), 0: measurements.WithAllBytes(0x11, measurements.Enforce),
1: measurements.WithAllBytes(0x22, measurements.Enforce), 1: measurements.WithAllBytes(0x22, measurements.Enforce),
2: measurements.WithAllBytes(0x33, measurements.Enforce), 2: measurements.WithAllBytes(0x33, measurements.Enforce),
}, },
want: []sorted.Measurement{ want: []Measurement{
{ {
Index: "PCR[00]", Index: "PCR[00]",
Value: bytes.Repeat([]byte{0x11}, 32), Value: bytes.Repeat([]byte{0x11}, 32),
@ -41,13 +42,14 @@ func TestSortMeasurements(t *testing.T) {
}, },
}, },
}, },
"unsorted": { "unsorted TPM": {
measurementType: TPM,
input: measurements.M{ input: measurements.M{
1: measurements.WithAllBytes(0x22, measurements.Enforce), 1: measurements.WithAllBytes(0x22, measurements.Enforce),
0: measurements.WithAllBytes(0x11, measurements.Enforce), 0: measurements.WithAllBytes(0x11, measurements.Enforce),
2: measurements.WithAllBytes(0x33, measurements.Enforce), 2: measurements.WithAllBytes(0x33, measurements.Enforce),
}, },
want: []sorted.Measurement{ want: []Measurement{
{ {
Index: "PCR[00]", Index: "PCR[00]",
Value: bytes.Repeat([]byte{0x11}, 32), Value: bytes.Repeat([]byte{0x11}, 32),
@ -62,13 +64,57 @@ func TestSortMeasurements(t *testing.T) {
}, },
}, },
}, },
"pre sorted TDX": {
measurementType: TDX,
input: measurements.M{
0: measurements.WithAllBytes(0x11, false),
1: measurements.WithAllBytes(0x22, false),
2: measurements.WithAllBytes(0x33, false),
},
want: []Measurement{
{
Index: "MRTD",
Value: bytes.Repeat([]byte{0x11}, 32),
},
{
Index: "RTMR[0]",
Value: bytes.Repeat([]byte{0x22}, 32),
},
{
Index: "RTMR[1]",
Value: bytes.Repeat([]byte{0x33}, 32),
},
},
},
"unsorted TDX": {
measurementType: TDX,
input: measurements.M{
1: measurements.WithAllBytes(0x22, false),
0: measurements.WithAllBytes(0x11, false),
2: measurements.WithAllBytes(0x33, false),
},
want: []Measurement{
{
Index: "MRTD",
Value: bytes.Repeat([]byte{0x11}, 32),
},
{
Index: "RTMR[0]",
Value: bytes.Repeat([]byte{0x22}, 32),
},
{
Index: "RTMR[1]",
Value: bytes.Repeat([]byte{0x33}, 32),
},
},
},
} }
for name, tc := range testCases { for name, tc := range testCases {
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
assert := assert.New(t) assert := assert.New(t)
got := sortMeasurements(tc.input) got := SortMeasurements(tc.input, tc.measurementType)
for i := range got { for i := range got {
assert.Equal(got[i].Index, tc.want[i].Index) assert.Equal(got[i].Index, tc.want[i].Index)
assert.Equal(got[i].Value, tc.want[i].Value) assert.Equal(got[i].Value, tc.want[i].Value)

View File

@ -8,10 +8,6 @@ SPDX-License-Identifier: AGPL-3.0-only
package tdx package tdx
import ( import (
"fmt"
"sort"
"github.com/edgelesssys/constellation/v2/internal/attestation/measurements"
"github.com/edgelesssys/constellation/v2/internal/attestation/tdx" "github.com/edgelesssys/constellation/v2/internal/attestation/tdx"
"github.com/edgelesssys/constellation/v2/measurement-reader/internal/sorted" "github.com/edgelesssys/constellation/v2/measurement-reader/internal/sorted"
) )
@ -23,36 +19,5 @@ func Measurements() ([]sorted.Measurement, error) {
return nil, err return nil, err
} }
return sortMeasurements(m), nil return sorted.SortMeasurements(m, sorted.TDX), nil
}
func sortMeasurements(m measurements.M) []sorted.Measurement {
keys := make([]uint32, 0, len(m))
for idx := range m {
keys = append(keys, idx)
}
sort.Slice(keys, func(i, j int) bool {
return keys[i] < keys[j]
})
var measurements []sorted.Measurement
for _, idx := range keys {
expected := m[idx].Expected
// Index 0 == MRTD
// Index 1-5 == RTMR[0-4]
var index string
if (idx) == 0 {
index = "MRTD"
} else {
index = fmt.Sprintf("RTMR[%01d]", idx-1)
}
measurements = append(measurements, sorted.Measurement{
Index: index,
Value: expected[:],
})
}
return measurements
} }

View File

@ -1,78 +0,0 @@
/*
Copyright (c) Edgeless Systems GmbH
SPDX-License-Identifier: AGPL-3.0-only
*/
package tdx
import (
"bytes"
"testing"
"github.com/edgelesssys/constellation/v2/internal/attestation/measurements"
"github.com/edgelesssys/constellation/v2/measurement-reader/internal/sorted"
"github.com/stretchr/testify/assert"
)
func TestSortMeasurements(t *testing.T) {
testCases := map[string]struct {
input measurements.M
want []sorted.Measurement
}{
"pre sorted": {
input: measurements.M{
0: measurements.WithAllBytes(0x11, false),
1: measurements.WithAllBytes(0x22, false),
2: measurements.WithAllBytes(0x33, false),
},
want: []sorted.Measurement{
{
Index: "MRTD",
Value: bytes.Repeat([]byte{0x11}, 32),
},
{
Index: "RTMR[0]",
Value: bytes.Repeat([]byte{0x22}, 32),
},
{
Index: "RTMR[1]",
Value: bytes.Repeat([]byte{0x33}, 32),
},
},
},
"unsorted": {
input: measurements.M{
1: measurements.WithAllBytes(0x22, false),
0: measurements.WithAllBytes(0x11, false),
2: measurements.WithAllBytes(0x33, false),
},
want: []sorted.Measurement{
{
Index: "MRTD",
Value: bytes.Repeat([]byte{0x11}, 32),
},
{
Index: "RTMR[0]",
Value: bytes.Repeat([]byte{0x22}, 32),
},
{
Index: "RTMR[1]",
Value: bytes.Repeat([]byte{0x33}, 32),
},
},
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
got := sortMeasurements(tc.input)
for i := range got {
assert.Equal(got[i].Index, tc.want[i].Index)
assert.Equal(got[i].Value, tc.want[i].Value)
}
})
}
}

View File

@ -8,10 +8,6 @@ SPDX-License-Identifier: AGPL-3.0-only
package tpm package tpm
import ( import (
"fmt"
"sort"
"github.com/edgelesssys/constellation/v2/internal/attestation/measurements"
"github.com/edgelesssys/constellation/v2/internal/attestation/vtpm" "github.com/edgelesssys/constellation/v2/internal/attestation/vtpm"
"github.com/edgelesssys/constellation/v2/measurement-reader/internal/sorted" "github.com/edgelesssys/constellation/v2/measurement-reader/internal/sorted"
tpmClient "github.com/google/go-tpm-tools/client" tpmClient "github.com/google/go-tpm-tools/client"
@ -25,26 +21,5 @@ func Measurements() ([]sorted.Measurement, error) {
return nil, err return nil, err
} }
return sortMeasurements(m), nil return sorted.SortMeasurements(m, sorted.TPM), nil
}
func sortMeasurements(m measurements.M) []sorted.Measurement {
keys := make([]uint32, 0, len(m))
for idx := range m {
keys = append(keys, idx)
}
sort.Slice(keys, func(i, j int) bool {
return keys[i] < keys[j]
})
var measurements []sorted.Measurement
for _, idx := range keys {
expected := m[idx].Expected
measurements = append(measurements, sorted.Measurement{
Index: fmt.Sprintf("PCR[%02d]", idx),
Value: expected[:],
})
}
return measurements
} }