constellation/internal/attestation/vtpm/attestation_test.go
2024-02-08 14:20:01 +00:00

491 lines
14 KiB
Go

/*
Copyright (c) Edgeless Systems GmbH
SPDX-License-Identifier: AGPL-3.0-only
*/
package vtpm
import (
"context"
"crypto"
"encoding/json"
"errors"
"fmt"
"io"
"os"
"testing"
tpmclient "github.com/google/go-tpm-tools/client"
"github.com/google/go-tpm-tools/proto/attest"
"github.com/google/go-tpm-tools/proto/tpm"
"github.com/google/go-tpm/legacy/tpm2"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/edgelesssys/constellation/v2/internal/attestation/initialize"
"github.com/edgelesssys/constellation/v2/internal/attestation/measurements"
"github.com/edgelesssys/constellation/v2/internal/attestation/simulator"
tpmsim "github.com/edgelesssys/constellation/v2/internal/attestation/simulator"
"github.com/edgelesssys/constellation/v2/internal/logger"
)
type simTPMWithEventLog struct {
io.ReadWriteCloser
}
func newSimTPMWithEventLog() (io.ReadWriteCloser, error) {
tpmSim, err := simulator.OpenSimulatedTPM()
if err != nil {
return nil, err
}
return &simTPMWithEventLog{tpmSim}, nil
}
// EventLog overrides the default event log getter.
func (s simTPMWithEventLog) EventLog() ([]byte, error) {
// event log header for successful parsing of event log
header := []byte{
0x00, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x25, 0x00, 0x00, 0x00, 0x53, 0x70, 0x65,
0x63, 0x20, 0x49, 0x44, 0x20, 0x45, 0x76, 0x65, 0x6E, 0x74, 0x30, 0x33, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x02, 0x00, 0x02, 0x02, 0x00, 0x00, 0x00, 0x04, 0x00, 0x14, 0x00, 0x0B, 0x00, 0x20, 0x00, 0x00,
}
return header, nil
}
func fakeGetInstanceInfo(_ context.Context, _ io.ReadWriteCloser, _ []byte) ([]byte, error) {
return []byte("unit-test"), nil
}
func TestValidate(t *testing.T) {
cgo := os.Getenv("CGO_ENABLED")
if cgo == "0" {
t.Skip("skipping test because CGO is disabled and tpm simulator requires it")
}
require := require.New(t)
fakeValidateCVM := func(AttestationDocument, *attest.MachineState) error { return nil }
fakeGetTrustedKey := func(_ context.Context, attDoc AttestationDocument, _ []byte) (crypto.PublicKey, error) {
pubArea, err := tpm2.DecodePublic(attDoc.Attestation.AkPub)
if err != nil {
return nil, err
}
return pubArea.Key()
}
testExpectedPCRs := measurements.M{
0: measurements.WithAllBytes(0x00, measurements.Enforce, measurements.PCRMeasurementLength),
1: measurements.WithAllBytes(0x00, measurements.Enforce, measurements.PCRMeasurementLength),
uint32(measurements.PCRIndexClusterID): measurements.WithAllBytes(0x00, measurements.Enforce, measurements.PCRMeasurementLength),
}
warnLog := &testAttestationLogger{}
tpmOpen, tpmCloser := tpmsim.NewSimulatedTPMOpenFunc()
defer tpmCloser.Close()
issuer := NewIssuer(tpmOpen, tpmclient.AttestationKeyRSA, fakeGetInstanceInfo, logger.NewTest(t))
validator := NewValidator(testExpectedPCRs, fakeGetTrustedKey, fakeValidateCVM, logger.NewTest(t))
nonce := []byte{1, 2, 3, 4}
challenge := []byte("Constellation")
ctx := context.Background()
attDocRaw, err := issuer.Issue(ctx, challenge, nonce)
require.NoError(err)
var attDoc AttestationDocument
err = json.Unmarshal(attDocRaw, &attDoc)
require.NoError(err)
require.Equal(challenge, attDoc.UserData)
// valid test
out, err := validator.Validate(ctx, attDocRaw, nonce)
require.NoError(err)
require.Equal(challenge, out)
// validation must fail after bootstrapping (change of enforced PCR)
require.NoError(initialize.MarkNodeAsBootstrapped(tpmOpen, []byte{2}))
attDocBootstrappedRaw, err := issuer.Issue(ctx, challenge, nonce)
require.NoError(err)
_, err = validator.Validate(ctx, attDocBootstrappedRaw, nonce)
require.Error(err)
// userData must be bound to PCR state
attDocBootstrappedRaw, err = issuer.Issue(ctx, []byte{2, 3}, nonce)
require.NoError(err)
var attDocBootstrapped AttestationDocument
require.NoError(json.Unmarshal(attDocBootstrappedRaw, &attDocBootstrapped))
attDocBootstrapped.Attestation = attDoc.Attestation
attDocBootstrappedRaw, err = json.Marshal(attDocBootstrapped)
require.NoError(err)
_, err = validator.Validate(ctx, attDocBootstrappedRaw, nonce)
require.Error(err)
expectedPCRs := measurements.M{
0: measurements.WithAllBytes(0x00, measurements.WarnOnly, measurements.PCRMeasurementLength),
1: measurements.WithAllBytes(0x00, measurements.WarnOnly, measurements.PCRMeasurementLength),
2: measurements.Measurement{
Expected: []byte{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f, 0x20},
ValidationOpt: measurements.WarnOnly,
},
3: measurements.Measurement{
Expected: []byte{0x21, 0x22, 0x23, 0x24, 0x25, 0x26, 0x27, 0x28, 0x29, 0x2a, 0x2b, 0x2c, 0x2d, 0x2e, 0x2f, 0x30, 0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37, 0x38, 0x39, 0x3a, 0x3b, 0x3c, 0x3d, 0x3e, 0x3f, 0x40},
ValidationOpt: measurements.WarnOnly,
},
4: measurements.Measurement{
Expected: []byte{0x41, 0x42, 0x43, 0x44, 0x45, 0x46, 0x47, 0x48, 0x49, 0x4a, 0x4b, 0x4c, 0x4d, 0x4e, 0x4f, 0x50, 0x51, 0x52, 0x53, 0x54, 0x55, 0x56, 0x57, 0x58, 0x59, 0x5a, 0x5b, 0x5c, 0x5d, 0x5e, 0x5f, 0x60},
ValidationOpt: measurements.WarnOnly,
},
5: measurements.Measurement{
Expected: []byte{0x61, 0x62, 0x63, 0x64, 0x65, 0x66, 0x67, 0x68, 0x69, 0x6a, 0x6b, 0x6c, 0x6d, 0x6e, 0x6f, 0x70, 0x71, 0x72, 0x73, 0x74, 0x75, 0x76, 0x77, 0x78, 0x79, 0x7a, 0x7b, 0x7c, 0x7d, 0x7e, 0x7f, 0x80},
ValidationOpt: measurements.WarnOnly,
},
}
warningValidator := NewValidator(
expectedPCRs,
fakeGetTrustedKey,
fakeValidateCVM,
warnLog,
)
out, err = warningValidator.Validate(ctx, attDocRaw, nonce)
require.NoError(err)
assert.Equal(t, challenge, out)
assert.Len(t, warnLog.warnings, 4)
testCases := map[string]struct {
validator *Validator
attDoc []byte
nonce []byte
wantErr bool
}{
"valid": {
validator: NewValidator(testExpectedPCRs, fakeGetTrustedKey, fakeValidateCVM, warnLog),
attDoc: mustMarshalAttestation(attDoc, require),
nonce: nonce,
},
"invalid nonce": {
validator: NewValidator(testExpectedPCRs, fakeGetTrustedKey, fakeValidateCVM, warnLog),
attDoc: mustMarshalAttestation(attDoc, require),
nonce: []byte{4, 3, 2, 1},
wantErr: true,
},
"invalid signature": {
validator: NewValidator(testExpectedPCRs, fakeGetTrustedKey, fakeValidateCVM, warnLog),
attDoc: mustMarshalAttestation(AttestationDocument{
Attestation: attDoc.Attestation,
InstanceInfo: attDoc.InstanceInfo,
UserData: []byte("wrong data"),
}, require),
nonce: nonce,
wantErr: true,
},
"untrusted attestation public key": {
validator: NewValidator(
testExpectedPCRs,
func(context.Context, AttestationDocument, []byte) (crypto.PublicKey, error) {
return nil, errors.New("untrusted")
},
fakeValidateCVM, warnLog),
attDoc: mustMarshalAttestation(attDoc, require),
nonce: nonce,
wantErr: true,
},
"not a CVM": {
validator: NewValidator(
testExpectedPCRs,
fakeGetTrustedKey,
func(AttestationDocument, *attest.MachineState) error {
return errors.New("untrusted")
},
warnLog),
attDoc: mustMarshalAttestation(attDoc, require),
nonce: nonce,
wantErr: true,
},
"untrusted PCRs": {
validator: NewValidator(
measurements.M{
0: measurements.Measurement{
Expected: []byte{0xFF},
ValidationOpt: measurements.Enforce,
},
1: measurements.Measurement{
Expected: []byte{0xFF},
ValidationOpt: measurements.Enforce,
},
},
fakeGetTrustedKey,
fakeValidateCVM,
warnLog),
attDoc: mustMarshalAttestation(attDoc, require),
nonce: nonce,
wantErr: true,
},
"untrusted WarnOnly PCRs": {
validator: NewValidator(
measurements.M{
0: measurements.Measurement{
Expected: []byte{0xFF},
ValidationOpt: measurements.WarnOnly,
},
1: measurements.Measurement{
Expected: []byte{0xFF},
ValidationOpt: measurements.WarnOnly,
},
},
fakeGetTrustedKey,
fakeValidateCVM,
logger.NewTest(t)),
attDoc: mustMarshalAttestation(attDoc, require),
nonce: nonce,
wantErr: false,
},
"no sha256 quote": {
validator: NewValidator(testExpectedPCRs, fakeGetTrustedKey, fakeValidateCVM, warnLog),
attDoc: mustMarshalAttestation(AttestationDocument{
Attestation: &attest.Attestation{
AkPub: attDoc.Attestation.AkPub,
Quotes: []*tpm.Quote{
attDoc.Attestation.Quotes[2],
},
EventLog: attDoc.Attestation.EventLog,
InstanceInfo: attDoc.Attestation.InstanceInfo,
},
InstanceInfo: attDoc.InstanceInfo,
UserData: attDoc.UserData,
}, require),
nonce: nonce,
wantErr: true,
},
"invalid attestation document": {
validator: NewValidator(testExpectedPCRs, fakeGetTrustedKey, fakeValidateCVM, warnLog),
attDoc: []byte("invalid attestation"),
nonce: nonce,
wantErr: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
_, err = tc.validator.Validate(ctx, tc.attDoc, tc.nonce)
if tc.wantErr {
assert.Error(err)
} else {
assert.NoError(err)
}
})
}
}
func mustMarshalAttestation(attDoc AttestationDocument, require *require.Assertions) []byte {
out, err := json.Marshal(attDoc)
require.NoError(err)
return out
}
func TestFailIssuer(t *testing.T) {
testCases := map[string]struct {
issuer *Issuer
userData []byte
nonce []byte
}{
"fail openTPM": {
issuer: NewIssuer(
func() (io.ReadWriteCloser, error) {
return nil, errors.New("failure")
},
tpmclient.AttestationKeyRSA,
fakeGetInstanceInfo,
nil,
),
userData: []byte("Constellation"),
nonce: []byte{1, 2, 3, 4},
},
"fail getAttestationKey": {
issuer: NewIssuer(
newSimTPMWithEventLog,
func(tpm io.ReadWriter) (*tpmclient.Key, error) {
return nil, errors.New("failure")
},
fakeGetInstanceInfo,
nil,
),
userData: []byte("Constellation"),
nonce: []byte{1, 2, 3, 4},
},
"fail Attest": {
issuer: NewIssuer(
newSimTPMWithEventLog,
func(tpm io.ReadWriter) (*tpmclient.Key, error) {
return &tpmclient.Key{}, nil
},
fakeGetInstanceInfo,
nil,
),
userData: []byte("Constellation"),
nonce: []byte{1, 2, 3, 4},
},
"fail getInstanceInfo": {
issuer: NewIssuer(
newSimTPMWithEventLog,
tpmclient.AttestationKeyRSA,
func(context.Context, io.ReadWriteCloser, []byte) ([]byte, error) { return nil, errors.New("failure") },
nil,
),
userData: []byte("Constellation"),
nonce: []byte{1, 2, 3, 4},
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
tc.issuer.log = logger.NewTest(t)
_, err := tc.issuer.Issue(context.Background(), tc.userData, tc.nonce)
assert.Error(err)
})
}
}
func TestGetSHA256QuoteIndex(t *testing.T) {
testCases := map[string]struct {
quotes []*tpm.Quote
wantIdx int
wantErr bool
}{
"idx 0 is valid": {
quotes: []*tpm.Quote{
{
Pcrs: &tpm.PCRs{
Hash: tpm.HashAlgo_SHA256,
},
},
},
wantIdx: 0,
wantErr: false,
},
"idx 1 is valid": {
quotes: []*tpm.Quote{
{
Pcrs: &tpm.PCRs{
Hash: tpm.HashAlgo_SHA1,
},
},
{
Pcrs: &tpm.PCRs{
Hash: tpm.HashAlgo_SHA256,
},
},
},
wantIdx: 1,
wantErr: false,
},
"no quotes": {
quotes: nil,
wantErr: true,
},
"quotes is nil": {
quotes: make([]*tpm.Quote, 2),
wantErr: true,
},
"pcrs is nil": {
quotes: []*tpm.Quote{
{
Pcrs: &tpm.PCRs{
Hash: tpm.HashAlgo_SHA1,
},
},
{
Pcrs: nil,
},
},
wantErr: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
idx, err := GetSHA256QuoteIndex(tc.quotes)
if tc.wantErr {
assert.Error(err)
assert.Equal(0, idx)
} else {
assert.NoError(err)
assert.Equal(tc.wantIdx, idx)
}
})
}
}
func TestGetSelectedMeasurements(t *testing.T) {
cgo := os.Getenv("CGO_ENABLED")
if cgo == "0" {
t.Skip("skipping test because CGO is disabled and tpm simulator requires it")
}
testCases := map[string]struct {
openFunc TPMOpenFunc
pcrSelection tpm2.PCRSelection
wantErr bool
}{
"error": {
openFunc: func() (io.ReadWriteCloser, error) { return nil, errors.New("error") },
pcrSelection: tpm2.PCRSelection{
Hash: tpm2.AlgSHA256,
PCRs: []int{0, 1, 2},
},
wantErr: true,
},
"3 PCRs": {
openFunc: simulator.OpenSimulatedTPM,
pcrSelection: tpm2.PCRSelection{
Hash: tpm2.AlgSHA256,
PCRs: []int{0, 1, 2},
},
},
"Azure PCRS": {
openFunc: simulator.OpenSimulatedTPM,
pcrSelection: AzurePCRSelection,
},
"GCP PCRs": {
openFunc: simulator.OpenSimulatedTPM,
pcrSelection: GCPPCRSelection,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
require := require.New(t)
assert := assert.New(t)
pcrs, err := GetSelectedMeasurements(tc.openFunc, tc.pcrSelection)
if tc.wantErr {
assert.Error(err)
return
}
require.NoError(err)
assert.Len(pcrs, len(tc.pcrSelection.PCRs))
})
}
}
type testAttestationLogger struct {
infos []string
warnings []string
}
func (w *testAttestationLogger) Info(format string, args ...any) {
w.infos = append(w.infos, fmt.Sprintf(format, args...))
}
func (w *testAttestationLogger) Warn(format string, args ...any) {
w.warnings = append(w.warnings, fmt.Sprintf(format, args...))
}