diff --git a/internal/attestation/measurements/measurements.go b/internal/attestation/measurements/measurements.go index 7f5655146..a5bde8951 100644 --- a/internal/attestation/measurements/measurements.go +++ b/internal/attestation/measurements/measurements.go @@ -78,6 +78,42 @@ type ImageMeasurementsV2Entry struct { Measurements M `json:"measurements" yaml:"measurements"` } +// MergeImageMeasurementsV2 combines the image measurement entries from multiple sources into a single +// ImageMeasurementsV2 object. +func MergeImageMeasurementsV2(measurements ...ImageMeasurementsV2) (ImageMeasurementsV2, error) { + if len(measurements) == 0 { + return ImageMeasurementsV2{}, errors.New("no measurement objects specified") + } + if len(measurements) == 1 { + return measurements[0], nil + } + out := ImageMeasurementsV2{ + Version: measurements[0].Version, + Ref: measurements[0].Ref, + Stream: measurements[0].Stream, + List: []ImageMeasurementsV2Entry{}, + } + for _, m := range measurements { + if m.Version != out.Version { + return ImageMeasurementsV2{}, errors.New("version mismatch") + } + if m.Ref != out.Ref { + return ImageMeasurementsV2{}, errors.New("ref mismatch") + } + if m.Stream != out.Stream { + return ImageMeasurementsV2{}, errors.New("stream mismatch") + } + out.List = append(out.List, m.List...) + } + sort.SliceStable(out.List, func(i, j int) bool { + if out.List[i].CSP != out.List[j].CSP { + return out.List[i].CSP < out.List[j].CSP + } + return out.List[i].AttestationVariant < out.List[j].AttestationVariant + }) + return out, nil +} + // MarshalYAML returns the YAML encoding of m. func (m M) MarshalYAML() (any, error) { // cast to prevent infinite recursion diff --git a/internal/attestation/measurements/measurements_test.go b/internal/attestation/measurements/measurements_test.go index 1f1127c45..3a8d9e741 100644 --- a/internal/attestation/measurements/measurements_test.go +++ b/internal/attestation/measurements/measurements_test.go @@ -746,3 +746,200 @@ func TestEqualTo(t *testing.T) { }) } } + +func TestMergeImageMeasurementsV2(t *testing.T) { + testCases := map[string]struct { + measurements []ImageMeasurementsV2 + wantMeasurements ImageMeasurementsV2 + wantErr bool + }{ + "only one element": { + measurements: []ImageMeasurementsV2{ + { + Ref: "test-ref", + Stream: "nightly", + Version: "v1.0.0", + List: []ImageMeasurementsV2Entry{ + { + CSP: cloudprovider.AWS, + AttestationVariant: "aws-nitro-tpm", + Measurements: M{ + 0: WithAllBytes(0x00, Enforce, PCRMeasurementLength), + }, + }, + }, + }, + }, + wantMeasurements: ImageMeasurementsV2{ + Ref: "test-ref", + Stream: "nightly", + Version: "v1.0.0", + List: []ImageMeasurementsV2Entry{ + { + CSP: cloudprovider.AWS, + AttestationVariant: "aws-nitro-tpm", + Measurements: M{ + 0: WithAllBytes(0x00, Enforce, PCRMeasurementLength), + }, + }, + }, + }, + }, + "valid measurements": { + measurements: []ImageMeasurementsV2{ + { + Ref: "test-ref", + Stream: "nightly", + Version: "v1.0.0", + List: []ImageMeasurementsV2Entry{ + { + CSP: cloudprovider.AWS, + AttestationVariant: "aws-nitro-tpm", + Measurements: M{ + 0: WithAllBytes(0x00, Enforce, PCRMeasurementLength), + }, + }, + }, + }, + { + Ref: "test-ref", + Stream: "nightly", + Version: "v1.0.0", + List: []ImageMeasurementsV2Entry{ + { + CSP: cloudprovider.GCP, + AttestationVariant: "gcp-sev-es", + Measurements: M{ + 1: WithAllBytes(0x11, Enforce, PCRMeasurementLength), + }, + }, + }, + }, + }, + wantMeasurements: ImageMeasurementsV2{ + Ref: "test-ref", + Stream: "nightly", + Version: "v1.0.0", + List: []ImageMeasurementsV2Entry{ + { + CSP: cloudprovider.AWS, + AttestationVariant: "aws-nitro-tpm", + Measurements: M{ + 0: WithAllBytes(0x00, Enforce, PCRMeasurementLength), + }, + }, + { + CSP: cloudprovider.GCP, + AttestationVariant: "gcp-sev-es", + Measurements: M{ + 1: WithAllBytes(0x11, Enforce, PCRMeasurementLength), + }, + }, + }, + }, + }, + "sorting": { + measurements: []ImageMeasurementsV2{ + { + Ref: "test-ref", + Stream: "nightly", + Version: "v1.0.0", + List: []ImageMeasurementsV2Entry{ + { + CSP: cloudprovider.GCP, + AttestationVariant: "gcp-sev-es", + Measurements: M{ + 1: WithAllBytes(0x11, Enforce, PCRMeasurementLength), + }, + }, + }, + }, + { + Ref: "test-ref", + Stream: "nightly", + Version: "v1.0.0", + List: []ImageMeasurementsV2Entry{ + { + CSP: cloudprovider.AWS, + AttestationVariant: "aws-nitro-tpm", + Measurements: M{ + 0: WithAllBytes(0x00, Enforce, PCRMeasurementLength), + }, + }, + }, + }, + }, + wantMeasurements: ImageMeasurementsV2{ + Ref: "test-ref", + Stream: "nightly", + Version: "v1.0.0", + List: []ImageMeasurementsV2Entry{ + { + CSP: cloudprovider.AWS, + AttestationVariant: "aws-nitro-tpm", + Measurements: M{ + 0: WithAllBytes(0x00, Enforce, PCRMeasurementLength), + }, + }, + { + CSP: cloudprovider.GCP, + AttestationVariant: "gcp-sev-es", + Measurements: M{ + 1: WithAllBytes(0x11, Enforce, PCRMeasurementLength), + }, + }, + }, + }, + }, + "mismatch in base info": { + measurements: []ImageMeasurementsV2{ + { + Ref: "test-ref", + Stream: "nightly", + Version: "v1.0.0", + List: []ImageMeasurementsV2Entry{ + { + CSP: cloudprovider.AWS, + AttestationVariant: "aws-nitro-tpm", + Measurements: M{ + 0: WithAllBytes(0x00, Enforce, PCRMeasurementLength), + }, + }, + }, + }, + { + Ref: "other-ref", + Stream: "stable", + Version: "v2.0.0", + List: []ImageMeasurementsV2Entry{ + { + CSP: cloudprovider.GCP, + AttestationVariant: "gcp-sev-es", + Measurements: M{ + 1: WithAllBytes(0x11, Enforce, PCRMeasurementLength), + }, + }, + }, + }, + }, + wantErr: true, + }, + "empty list": { + wantErr: true, + }, + } + + for name, tc := range testCases { + t.Run(name, func(t *testing.T) { + assert := assert.New(t) + gotMeasurements, err := MergeImageMeasurementsV2(tc.measurements...) + + if tc.wantErr { + assert.Error(err) + return + } + assert.NoError(err) + assert.Equal(tc.wantMeasurements, gotMeasurements) + }) + } +}