constellation/internal/attestation/measurements/measurements_test.go

404 lines
12 KiB
Go
Raw Normal View History

/*
Copyright (c) Edgeless Systems GmbH
SPDX-License-Identifier: AGPL-3.0-only
*/
package measurements
import (
"context"
"errors"
"io"
"net/http"
"net/url"
"strings"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestMarshalYAML(t *testing.T) {
testCases := map[string]struct {
measurements M
wantBase64Map map[uint32]string
}{
"valid measurements": {
measurements: M{
2: []byte{253, 93, 233, 223, 53, 14, 59, 196, 65, 10, 192, 107, 191, 229, 204, 222, 185, 63, 83, 185, 239, 81, 35, 159, 117, 44, 230, 157, 188, 96, 15, 53},
3: []byte{213, 164, 73, 109, 33, 222, 201, 165, 37, 141, 219, 25, 198, 254, 181, 59, 180, 211, 192, 70, 63, 230, 7, 242, 72, 141, 223, 79, 16, 6, 239, 158},
},
wantBase64Map: map[uint32]string{
2: "/V3p3zUOO8RBCsBrv+XM3rk/U7nvUSOfdSzmnbxgDzU=",
3: "1aRJbSHeyaUljdsZxv61O7TTwEY/5gfySI3fTxAG754=",
},
},
"omit bytes": {
measurements: M{
2: []byte{},
3: []byte{1, 2, 3, 4},
},
wantBase64Map: map[uint32]string{
2: "",
3: "AQIDBA==",
},
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
base64Map, err := tc.measurements.MarshalYAML()
require.NoError(err)
assert.Equal(tc.wantBase64Map, base64Map)
})
}
}
func TestUnmarshalYAML(t *testing.T) {
testCases := map[string]struct {
inputBase64Map map[uint32]string
forceUnmarshalError bool
wantMeasurements M
wantErr bool
}{
"valid measurements": {
inputBase64Map: map[uint32]string{
2: "/V3p3zUOO8RBCsBrv+XM3rk/U7nvUSOfdSzmnbxgDzU=",
3: "1aRJbSHeyaUljdsZxv61O7TTwEY/5gfySI3fTxAG754=",
},
wantMeasurements: M{
2: []byte{253, 93, 233, 223, 53, 14, 59, 196, 65, 10, 192, 107, 191, 229, 204, 222, 185, 63, 83, 185, 239, 81, 35, 159, 117, 44, 230, 157, 188, 96, 15, 53},
3: []byte{213, 164, 73, 109, 33, 222, 201, 165, 37, 141, 219, 25, 198, 254, 181, 59, 180, 211, 192, 70, 63, 230, 7, 242, 72, 141, 223, 79, 16, 6, 239, 158},
},
},
"empty bytes": {
inputBase64Map: map[uint32]string{
2: "AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=",
3: "AQIDBAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=",
},
wantMeasurements: M{
2: []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
3: []byte{1, 2, 3, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
},
},
"invalid base64": {
inputBase64Map: map[uint32]string{
2: "This is not base64",
3: "AQIDBAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=",
},
wantMeasurements: M{
2: []byte{},
3: []byte{1, 2, 3, 4},
},
wantErr: true,
},
"simulated unmarshal error": {
inputBase64Map: map[uint32]string{
2: "AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=",
3: "AQIDBAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=",
},
forceUnmarshalError: true,
wantMeasurements: M{
2: []byte{},
3: []byte{1, 2, 3, 4},
},
wantErr: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
var m M
2022-10-25 09:51:23 -04:00
err := m.UnmarshalYAML(func(i any) error {
if base64Map, ok := i.(map[uint32]string); ok {
for key, value := range tc.inputBase64Map {
base64Map[key] = value
}
}
if tc.forceUnmarshalError {
return errors.New("unmarshal error")
}
return nil
})
if tc.wantErr {
assert.Error(err)
} else {
require.NoError(err)
assert.Equal(tc.wantMeasurements, m)
}
})
}
}
func TestMeasurementsCopyFrom(t *testing.T) {
testCases := map[string]struct {
current M
newMeasurements M
wantMeasurements M
}{
"add to empty": {
current: M{},
newMeasurements: M{
1: PCRWithAllBytes(0x00),
2: PCRWithAllBytes(0x01),
3: PCRWithAllBytes(0x02),
},
wantMeasurements: M{
1: PCRWithAllBytes(0x00),
2: PCRWithAllBytes(0x01),
3: PCRWithAllBytes(0x02),
},
},
"keep existing": {
current: M{
4: PCRWithAllBytes(0x01),
5: PCRWithAllBytes(0x02),
},
newMeasurements: M{
1: PCRWithAllBytes(0x00),
2: PCRWithAllBytes(0x01),
3: PCRWithAllBytes(0x02),
},
wantMeasurements: M{
1: PCRWithAllBytes(0x00),
2: PCRWithAllBytes(0x01),
3: PCRWithAllBytes(0x02),
4: PCRWithAllBytes(0x01),
5: PCRWithAllBytes(0x02),
},
},
"overwrite existing": {
current: M{
2: PCRWithAllBytes(0x04),
3: PCRWithAllBytes(0x05),
},
newMeasurements: M{
1: PCRWithAllBytes(0x00),
2: PCRWithAllBytes(0x01),
3: PCRWithAllBytes(0x02),
},
wantMeasurements: M{
1: PCRWithAllBytes(0x00),
2: PCRWithAllBytes(0x01),
3: PCRWithAllBytes(0x02),
},
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
tc.current.CopyFrom(tc.newMeasurements)
assert.Equal(tc.wantMeasurements, tc.current)
})
}
}
// roundTripFunc .
type roundTripFunc func(req *http.Request) *http.Response
// RoundTrip .
func (f roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) {
return f(req), nil
}
// newTestClient returns *http.Client with Transport replaced to avoid making real calls.
func newTestClient(fn roundTripFunc) *http.Client {
return &http.Client{
Transport: fn,
}
}
func urlMustParse(raw string) *url.URL {
parsed, _ := url.Parse(raw)
return parsed
}
func TestMeasurementsFetchAndVerify(t *testing.T) {
testCases := map[string]struct {
measurements string
measurementsStatus int
signature string
signatureStatus int
publicKey []byte
wantMeasurements M
wantSHA string
wantError bool
}{
"simple": {
measurements: "0: AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=\n",
measurementsStatus: http.StatusOK,
signature: "MEUCIBs1g2/n0FsgPfJ+0uLD5TaunGhxwDcQcUGBroejKvg3AiEAzZtcLU9O6IiVhxB8tBS+ty6MXoPNwL8WRWMzyr35eKI=",
signatureStatus: http.StatusOK,
publicKey: []byte("-----BEGIN PUBLIC KEY-----\nMFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEUs5fDUIz9aiwrfr8BK4VjN7jE6sl\ngz7UuXsOin8+dB0SGrbNHy7TJToa2fAiIKPVLTOfvY75DqRAtffhO1fpBA==\n-----END PUBLIC KEY-----"),
wantMeasurements: M{
0: PCRWithAllBytes(0x00),
},
wantSHA: "4cd9d6ed8d9322150dff7738994c5e2fabff35f3bae6f5c993412d13249a5e87",
},
"404 measurements": {
measurements: "0: AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=\n",
measurementsStatus: http.StatusNotFound,
signature: "MEUCIBs1g2/n0FsgPfJ+0uLD5TaunGhxwDcQcUGBroejKvg3AiEAzZtcLU9O6IiVhxB8tBS+ty6MXoPNwL8WRWMzyr35eKI=",
signatureStatus: http.StatusOK,
publicKey: []byte("-----BEGIN PUBLIC KEY-----\nMFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEUs5fDUIz9aiwrfr8BK4VjN7jE6sl\ngz7UuXsOin8+dB0SGrbNHy7TJToa2fAiIKPVLTOfvY75DqRAtffhO1fpBA==\n-----END PUBLIC KEY-----"),
wantError: true,
},
"404 signature": {
measurements: "0: AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=\n",
measurementsStatus: http.StatusOK,
signature: "MEUCIBs1g2/n0FsgPfJ+0uLD5TaunGhxwDcQcUGBroejKvg3AiEAzZtcLU9O6IiVhxB8tBS+ty6MXoPNwL8WRWMzyr35eKI=",
signatureStatus: http.StatusNotFound,
publicKey: []byte("-----BEGIN PUBLIC KEY-----\nMFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEUs5fDUIz9aiwrfr8BK4VjN7jE6sl\ngz7UuXsOin8+dB0SGrbNHy7TJToa2fAiIKPVLTOfvY75DqRAtffhO1fpBA==\n-----END PUBLIC KEY-----"),
wantError: true,
},
"broken signature": {
measurements: "0: AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=\n",
measurementsStatus: http.StatusOK,
signature: "AAAAAAs1g2/n0FsgPfJ+0uLD5TaunGhxwDcQcUGBroejKvg3AiEAzZtcLU9O6IiVhxB8tBS+ty6MXoPNwL8WRWMzyr35eKI=",
signatureStatus: http.StatusOK,
publicKey: []byte("-----BEGIN PUBLIC KEY-----\nMFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEUs5fDUIz9aiwrfr8BK4VjN7jE6sl\ngz7UuXsOin8+dB0SGrbNHy7TJToa2fAiIKPVLTOfvY75DqRAtffhO1fpBA==\n-----END PUBLIC KEY-----"),
wantError: true,
},
"not yaml": {
measurements: "This is some content to be signed!\n",
measurementsStatus: http.StatusOK,
signature: "MEUCIQDzMN3yaiO9sxLGAaSA9YD8rLwzvOaZKWa/bzkcjImUFAIgXLLGzClYUd1dGbuEiY3O/g/eiwQYlyxqLQalxjFmz+8=",
signatureStatus: http.StatusOK,
publicKey: []byte("-----BEGIN PUBLIC KEY-----\nMFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAElWUhon39eAqzEC+/GP03oY4/MQg+\ngCDlEzkuOCybCHf+q766bve799L7Y5y5oRsHY1MrUCUwYF/tL7Sg7EYMsA==\n-----END PUBLIC KEY-----"),
wantError: true,
},
}
measurementsURL := urlMustParse("https://somesite.com/measurements.yaml")
signatureURL := urlMustParse("https://somesite.com/measurements.yaml.sig")
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
client := newTestClient(func(req *http.Request) *http.Response {
if req.URL.String() == measurementsURL.String() {
return &http.Response{
StatusCode: tc.measurementsStatus,
Body: io.NopCloser(strings.NewReader(tc.measurements)),
Header: make(http.Header),
}
}
if req.URL.String() == signatureURL.String() {
return &http.Response{
StatusCode: tc.signatureStatus,
Body: io.NopCloser(strings.NewReader(tc.signature)),
Header: make(http.Header),
}
}
return &http.Response{
StatusCode: http.StatusNotFound,
Body: io.NopCloser(strings.NewReader("Not found.")),
Header: make(http.Header),
}
})
m := M{}
hash, err := m.FetchAndVerify(context.Background(), client, measurementsURL, signatureURL, tc.publicKey)
if tc.wantError {
assert.Error(err)
return
}
assert.Equal(tc.wantSHA, hash)
assert.NoError(err)
assert.EqualValues(tc.wantMeasurements, m)
})
}
}
func TestPCRWithAllBytes(t *testing.T) {
testCases := map[string]struct {
b byte
wantPCR []byte
}{
"0x00": {
b: 0x00,
wantPCR: []byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00},
},
"0x01": {
b: 0x01,
wantPCR: []byte{0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01},
},
"0xFF": {
b: 0xFF,
wantPCR: []byte{0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF},
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
pcr := PCRWithAllBytes(tc.b)
assert.Equal(tc.wantPCR, pcr)
})
}
}
func TestEqualTo(t *testing.T) {
testCases := map[string]struct {
given M
other M
wantEqual bool
}{
"same values": {
given: M{
0: PCRWithAllBytes(0x00),
1: PCRWithAllBytes(0xFF),
},
other: M{
0: PCRWithAllBytes(0x00),
1: PCRWithAllBytes(0xFF),
},
wantEqual: true,
},
"different number of elements": {
given: M{
0: PCRWithAllBytes(0x00),
1: PCRWithAllBytes(0xFF),
},
other: M{
0: PCRWithAllBytes(0x00),
},
wantEqual: false,
},
"different values": {
given: M{
0: PCRWithAllBytes(0x00),
1: PCRWithAllBytes(0xFF),
},
other: M{
0: PCRWithAllBytes(0xFF),
1: PCRWithAllBytes(0x00),
},
wantEqual: false,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
if tc.wantEqual {
assert.True(tc.given.EqualTo(tc.other))
} else {
assert.False(tc.given.EqualTo(tc.other))
}
})
}
}