diff --git a/measurement-reader/internal/sorted/sorted.go b/measurement-reader/internal/sorted/sorted.go index 523103913..637709bc0 100644 --- a/measurement-reader/internal/sorted/sorted.go +++ b/measurement-reader/internal/sorted/sorted.go @@ -4,11 +4,67 @@ Copyright (c) Edgeless Systems GmbH SPDX-License-Identifier: AGPL-3.0-only */ -// Type definition for sorted measurements. +// Package sorted defines a type for print-friendly sorted measurements and allows sorting TPM and TDX measurements. package sorted +import ( + "fmt" + "sort" + + "github.com/edgelesssys/constellation/v2/internal/attestation/measurements" +) + // Measurement wraps a measurement custom index and value. type Measurement struct { Index string Value []byte } + +// MeasurementType are the supported attestation types we can sort. +type MeasurementType uint32 + +const ( + TPM MeasurementType = iota + TDX +) + +// SortMeasurements returns the sorted measurements for either TPM or TDX measurements. +func SortMeasurements(m measurements.M, measurementType MeasurementType) []Measurement { + if measurementType != TPM && measurementType != TDX { + return nil + } + + keys := make([]uint32, 0, len(m)) + for idx := range m { + keys = append(keys, idx) + } + sort.Slice(keys, func(i, j int) bool { + return keys[i] < keys[j] + }) + + var sortedMeasurements []Measurement + + for _, idx := range keys { + var index string + switch measurementType { + case TPM: + index = fmt.Sprintf("PCR[%02d]", idx) + case TDX: + // idx 0 is MRTD + if idx == 0 { + index = "MRTD" + break + } + // RTMR 0 starts at idx 1, so we have to subtract by one here. + index = fmt.Sprintf("RTMR[%01d]", idx-1) + } + + expected := m[idx].Expected + sortedMeasurements = append(sortedMeasurements, Measurement{ + Index: index, + Value: expected[:], + }) + } + + return sortedMeasurements +} diff --git a/measurement-reader/internal/tpm/tpm_test.go b/measurement-reader/internal/sorted/sorted_test.go similarity index 53% rename from measurement-reader/internal/tpm/tpm_test.go rename to measurement-reader/internal/sorted/sorted_test.go index d44ddcbe1..47527e533 100644 --- a/measurement-reader/internal/tpm/tpm_test.go +++ b/measurement-reader/internal/sorted/sorted_test.go @@ -4,29 +4,30 @@ Copyright (c) Edgeless Systems GmbH SPDX-License-Identifier: AGPL-3.0-only */ -package tpm +package sorted import ( "bytes" "testing" "github.com/edgelesssys/constellation/v2/internal/attestation/measurements" - "github.com/edgelesssys/constellation/v2/measurement-reader/internal/sorted" "github.com/stretchr/testify/assert" ) func TestSortMeasurements(t *testing.T) { testCases := map[string]struct { - input measurements.M - want []sorted.Measurement + measurementType MeasurementType + input measurements.M + want []Measurement }{ - "pre sorted": { + "pre sorted TPM": { + measurementType: TPM, input: measurements.M{ 0: measurements.WithAllBytes(0x11, measurements.Enforce), 1: measurements.WithAllBytes(0x22, measurements.Enforce), 2: measurements.WithAllBytes(0x33, measurements.Enforce), }, - want: []sorted.Measurement{ + want: []Measurement{ { Index: "PCR[00]", Value: bytes.Repeat([]byte{0x11}, 32), @@ -41,13 +42,14 @@ func TestSortMeasurements(t *testing.T) { }, }, }, - "unsorted": { + "unsorted TPM": { + measurementType: TPM, input: measurements.M{ 1: measurements.WithAllBytes(0x22, measurements.Enforce), 0: measurements.WithAllBytes(0x11, measurements.Enforce), 2: measurements.WithAllBytes(0x33, measurements.Enforce), }, - want: []sorted.Measurement{ + want: []Measurement{ { Index: "PCR[00]", Value: bytes.Repeat([]byte{0x11}, 32), @@ -62,13 +64,57 @@ func TestSortMeasurements(t *testing.T) { }, }, }, + "pre sorted TDX": { + measurementType: TDX, + input: measurements.M{ + 0: measurements.WithAllBytes(0x11, false), + 1: measurements.WithAllBytes(0x22, false), + 2: measurements.WithAllBytes(0x33, false), + }, + want: []Measurement{ + { + Index: "MRTD", + Value: bytes.Repeat([]byte{0x11}, 32), + }, + { + Index: "RTMR[0]", + Value: bytes.Repeat([]byte{0x22}, 32), + }, + { + Index: "RTMR[1]", + Value: bytes.Repeat([]byte{0x33}, 32), + }, + }, + }, + "unsorted TDX": { + measurementType: TDX, + input: measurements.M{ + 1: measurements.WithAllBytes(0x22, false), + 0: measurements.WithAllBytes(0x11, false), + 2: measurements.WithAllBytes(0x33, false), + }, + want: []Measurement{ + { + Index: "MRTD", + Value: bytes.Repeat([]byte{0x11}, 32), + }, + { + Index: "RTMR[0]", + Value: bytes.Repeat([]byte{0x22}, 32), + }, + { + Index: "RTMR[1]", + Value: bytes.Repeat([]byte{0x33}, 32), + }, + }, + }, } for name, tc := range testCases { t.Run(name, func(t *testing.T) { assert := assert.New(t) - got := sortMeasurements(tc.input) + got := SortMeasurements(tc.input, tc.measurementType) for i := range got { assert.Equal(got[i].Index, tc.want[i].Index) assert.Equal(got[i].Value, tc.want[i].Value) diff --git a/measurement-reader/internal/tdx/tdx.go b/measurement-reader/internal/tdx/tdx.go index 403d71e16..9c90aa8df 100644 --- a/measurement-reader/internal/tdx/tdx.go +++ b/measurement-reader/internal/tdx/tdx.go @@ -8,10 +8,6 @@ SPDX-License-Identifier: AGPL-3.0-only package tdx import ( - "fmt" - "sort" - - "github.com/edgelesssys/constellation/v2/internal/attestation/measurements" "github.com/edgelesssys/constellation/v2/internal/attestation/tdx" "github.com/edgelesssys/constellation/v2/measurement-reader/internal/sorted" ) @@ -23,36 +19,5 @@ func Measurements() ([]sorted.Measurement, error) { return nil, err } - return sortMeasurements(m), nil -} - -func sortMeasurements(m measurements.M) []sorted.Measurement { - keys := make([]uint32, 0, len(m)) - for idx := range m { - keys = append(keys, idx) - } - sort.Slice(keys, func(i, j int) bool { - return keys[i] < keys[j] - }) - - var measurements []sorted.Measurement - for _, idx := range keys { - expected := m[idx].Expected - - // Index 0 == MRTD - // Index 1-5 == RTMR[0-4] - var index string - if (idx) == 0 { - index = "MRTD" - } else { - index = fmt.Sprintf("RTMR[%01d]", idx-1) - } - - measurements = append(measurements, sorted.Measurement{ - Index: index, - Value: expected[:], - }) - } - - return measurements + return sorted.SortMeasurements(m, sorted.TDX), nil } diff --git a/measurement-reader/internal/tdx/tdx_test.go b/measurement-reader/internal/tdx/tdx_test.go deleted file mode 100644 index d12ef4168..000000000 --- a/measurement-reader/internal/tdx/tdx_test.go +++ /dev/null @@ -1,78 +0,0 @@ -/* -Copyright (c) Edgeless Systems GmbH - -SPDX-License-Identifier: AGPL-3.0-only -*/ - -package tdx - -import ( - "bytes" - "testing" - - "github.com/edgelesssys/constellation/v2/internal/attestation/measurements" - "github.com/edgelesssys/constellation/v2/measurement-reader/internal/sorted" - "github.com/stretchr/testify/assert" -) - -func TestSortMeasurements(t *testing.T) { - testCases := map[string]struct { - input measurements.M - want []sorted.Measurement - }{ - "pre sorted": { - input: measurements.M{ - 0: measurements.WithAllBytes(0x11, false), - 1: measurements.WithAllBytes(0x22, false), - 2: measurements.WithAllBytes(0x33, false), - }, - want: []sorted.Measurement{ - { - Index: "MRTD", - Value: bytes.Repeat([]byte{0x11}, 32), - }, - { - Index: "RTMR[0]", - Value: bytes.Repeat([]byte{0x22}, 32), - }, - { - Index: "RTMR[1]", - Value: bytes.Repeat([]byte{0x33}, 32), - }, - }, - }, - "unsorted": { - input: measurements.M{ - 1: measurements.WithAllBytes(0x22, false), - 0: measurements.WithAllBytes(0x11, false), - 2: measurements.WithAllBytes(0x33, false), - }, - want: []sorted.Measurement{ - { - Index: "MRTD", - Value: bytes.Repeat([]byte{0x11}, 32), - }, - { - Index: "RTMR[0]", - Value: bytes.Repeat([]byte{0x22}, 32), - }, - { - Index: "RTMR[1]", - Value: bytes.Repeat([]byte{0x33}, 32), - }, - }, - }, - } - - for name, tc := range testCases { - t.Run(name, func(t *testing.T) { - assert := assert.New(t) - - got := sortMeasurements(tc.input) - for i := range got { - assert.Equal(got[i].Index, tc.want[i].Index) - assert.Equal(got[i].Value, tc.want[i].Value) - } - }) - } -} diff --git a/measurement-reader/internal/tpm/tpm.go b/measurement-reader/internal/tpm/tpm.go index 4bfe91754..628ae6502 100644 --- a/measurement-reader/internal/tpm/tpm.go +++ b/measurement-reader/internal/tpm/tpm.go @@ -8,10 +8,6 @@ SPDX-License-Identifier: AGPL-3.0-only package tpm import ( - "fmt" - "sort" - - "github.com/edgelesssys/constellation/v2/internal/attestation/measurements" "github.com/edgelesssys/constellation/v2/internal/attestation/vtpm" "github.com/edgelesssys/constellation/v2/measurement-reader/internal/sorted" tpmClient "github.com/google/go-tpm-tools/client" @@ -25,26 +21,5 @@ func Measurements() ([]sorted.Measurement, error) { return nil, err } - return sortMeasurements(m), nil -} - -func sortMeasurements(m measurements.M) []sorted.Measurement { - keys := make([]uint32, 0, len(m)) - for idx := range m { - keys = append(keys, idx) - } - sort.Slice(keys, func(i, j int) bool { - return keys[i] < keys[j] - }) - - var measurements []sorted.Measurement - for _, idx := range keys { - expected := m[idx].Expected - measurements = append(measurements, sorted.Measurement{ - Index: fmt.Sprintf("PCR[%02d]", idx), - Value: expected[:], - }) - } - - return measurements + return sorted.SortMeasurements(m, sorted.TPM), nil }