370 lines
9.2 KiB
Go
Raw Normal View History

package vtpm
import (
"crypto"
"encoding/json"
"errors"
"io"
"testing"
tpmclient "github.com/google/go-tpm-tools/client"
"github.com/google/go-tpm-tools/proto/attest"
"github.com/google/go-tpm-tools/proto/tpm"
"github.com/google/go-tpm-tools/simulator"
"github.com/google/go-tpm/tpm2"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// simTPMWithEventLog wraps the TPM simulator by embedding it, so that all
// simulator methods are promoted and the EventLog method defined on this type
// takes precedence over the simulator's own event log getter.
type simTPMWithEventLog struct {
*simulator.Simulator
}
// newSimTPMWithEventLog obtains a TPM simulator through simulator.Get and
// returns it wrapped in a simTPMWithEventLog. An error from simulator.Get is
// returned unchanged together with a nil device.
func newSimTPMWithEventLog() (io.ReadWriteCloser, error) {
	sim, err := simulator.Get()
	if err != nil {
		return nil, err
	}

	wrapped := &simTPMWithEventLog{Simulator: sim}
	return wrapped, nil
}
// EventLog overrides the default event log getter.
//
// It always returns the same fixed byte sequence and a nil error. The bytes
// form a minimal event log header (they include the ASCII signature
// "Spec ID Event03") so that the event log can be parsed successfully.
func (simTPMWithEventLog) EventLog() ([]byte, error) {
	return []byte{
		0x00, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
		0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x25, 0x00, 0x00, 0x00, 0x53, 0x70, 0x65,
		0x63, 0x20, 0x49, 0x44, 0x20, 0x45, 0x76, 0x65, 0x6E, 0x74, 0x30, 0x33, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
		0x02, 0x00, 0x02, 0x02, 0x00, 0x00, 0x00, 0x04, 0x00, 0x14, 0x00, 0x0B, 0x00, 0x20, 0x00, 0x00,
	}, nil
}
// fakeGetInstanceInfo is a test stand-in for the instance info getter.
// It ignores the TPM handle and always returns the bytes of "unit-test"
// with a nil error.
func fakeGetInstanceInfo(tpm io.ReadWriteCloser) ([]byte, error) {
	const instanceInfo = "unit-test"
	return []byte(instanceInfo), nil
}
// fakeGetTrustedKey is a test stand-in for the trusted key getter.
// It hands aKPub to tpm2.DecodePublic and returns whatever Key() yields on
// the decoded public area; a decoding error is returned unchanged.
// instanceInfo is not inspected.
func fakeGetTrustedKey(aKPub, instanceInfo []byte) (crypto.PublicKey, error) {
	decoded, decodeErr := tpm2.DecodePublic(aKPub)
	if decodeErr != nil {
		return nil, decodeErr
	}

	return decoded.Key()
}
// fakeValidateCVM is a test stand-in for the CVM validation callback.
// It accepts every attestation document by always returning nil.
func fakeValidateCVM(AttestationDocument) error {
	return nil
}
// fakeTrustedPcrs is the set of expected PCR values handed to validators in
// these tests: PCR indices 0 and 1, each mapped to 32 zero bytes.
var fakeTrustedPcrs = map[uint32][]byte{
	0: make([]byte, 32),
	1: make([]byte, 32),
}
// TestValidate issues an attestation document using the simulated TPM, checks
// that an unmodified document validates and yields the original user data,
// and then checks that validation fails for every tampered document or
// restrictive validator configuration listed in the test table.
func TestValidate(t *testing.T) {
	require := require.New(t)

	issuer := NewIssuer(newSimTPMWithEventLog, tpmclient.AttestationKeyRSA, fakeGetInstanceInfo)
	validator := NewValidator(fakeTrustedPcrs, fakeGetTrustedKey, fakeValidateCVM, VerifyPKCS1v15)

	nonce := []byte{1, 2, 3, 4}
	challenge := []byte("Constellation")

	attDocRaw, err := issuer.Issue(challenge, nonce)
	require.NoError(err)

	var attDoc AttestationDocument
	err = json.Unmarshal(attDocRaw, &attDoc)
	require.NoError(err)
	require.Equal(challenge, attDoc.UserData)

	// valid test
	out, err := validator.Validate(attDocRaw, nonce)
	require.NoError(err)
	require.Equal(challenge, out)

	// The "no sha256 quote" case below picks Quotes[2]. Fail with a clear
	// message instead of panicking while building the table if the issued
	// document carries no attestation or fewer than three quotes.
	// NOTE(review): this assumes the quote at index 2 is not the SHA256
	// quote — confirm against the bank order the TPM reports.
	require.NotNil(attDoc.Attestation)
	require.Greater(len(attDoc.Attestation.Quotes), 2, "need at least 3 quotes for the \"no sha256 quote\" case")

	testCases := map[string]struct {
		validator   *Validator
		attDoc      []byte
		nonce       []byte
		errExpected bool
	}{
		"invalid nonce": {
			validator:   NewValidator(fakeTrustedPcrs, fakeGetTrustedKey, fakeValidateCVM, VerifyPKCS1v15),
			attDoc:      mustMarshalAttestation(attDoc, require),
			nonce:       []byte{4, 3, 2, 1},
			errExpected: true,
		},
		"invalid signature": {
			validator: NewValidator(fakeTrustedPcrs, fakeGetTrustedKey, fakeValidateCVM, VerifyPKCS1v15),
			attDoc: mustMarshalAttestation(AttestationDocument{
				Attestation:       attDoc.Attestation,
				InstanceInfo:      attDoc.InstanceInfo,
				UserData:          []byte("wrong data"),
				UserDataSignature: attDoc.UserDataSignature,
			}, require),
			nonce:       nonce,
			errExpected: true,
		},
		"untrusted attestation public key": {
			validator: NewValidator(
				fakeTrustedPcrs,
				func(akPub, instanceInfo []byte) (crypto.PublicKey, error) {
					return nil, errors.New("untrusted")
				},
				fakeValidateCVM, VerifyPKCS1v15),
			attDoc:      mustMarshalAttestation(attDoc, require),
			nonce:       nonce,
			errExpected: true,
		},
		"not a CVM": {
			validator: NewValidator(
				fakeTrustedPcrs,
				fakeGetTrustedKey,
				func(attestation AttestationDocument) error {
					return errors.New("untrusted")
				},
				VerifyPKCS1v15),
			attDoc:      mustMarshalAttestation(attDoc, require),
			nonce:       nonce,
			errExpected: true,
		},
		"untrusted PCRs": {
			validator: NewValidator(
				map[uint32][]byte{
					0: {0xFF},
				},
				fakeGetTrustedKey,
				fakeValidateCVM,
				VerifyPKCS1v15),
			attDoc:      mustMarshalAttestation(attDoc, require),
			nonce:       nonce,
			errExpected: true,
		},
		"no sha256 quote": {
			validator: NewValidator(fakeTrustedPcrs, fakeGetTrustedKey, fakeValidateCVM, VerifyPKCS1v15),
			attDoc: mustMarshalAttestation(AttestationDocument{
				Attestation: &attest.Attestation{
					AkPub: attDoc.Attestation.AkPub,
					Quotes: []*tpm.Quote{
						attDoc.Attestation.Quotes[2],
					},
					EventLog:     attDoc.Attestation.EventLog,
					InstanceInfo: attDoc.Attestation.InstanceInfo,
				},
				InstanceInfo:      attDoc.InstanceInfo,
				UserData:          attDoc.UserData,
				UserDataSignature: attDoc.UserDataSignature,
			}, require),
			nonce:       nonce,
			errExpected: true,
		},
		"invalid attestation document": {
			validator:   NewValidator(fakeTrustedPcrs, fakeGetTrustedKey, fakeValidateCVM, VerifyPKCS1v15),
			attDoc:      []byte("invalid attestation"),
			nonce:       nonce,
			errExpected: true,
		},
	}

	for name, tc := range testCases {
		t.Run(name, func(t *testing.T) {
			assert := assert.New(t)

			// Declare err inside the subtest: assigning to the enclosing
			// function's err would make all subtests share one variable.
			_, err := tc.validator.Validate(tc.attDoc, tc.nonce)
			if tc.errExpected {
				assert.Error(err)
			} else {
				assert.NoError(err)
			}
		})
	}
}
// mustMarshalAttestation JSON-encodes attDoc and returns the encoded bytes.
// A marshalling error is reported through require.NoError, which fails the
// test that owns the given assertions object.
func mustMarshalAttestation(attDoc AttestationDocument, require *require.Assertions) []byte {
	encoded, marshalErr := json.Marshal(attDoc)
	require.NoError(marshalErr)

	return encoded
}
func TestFailIssuer(t *testing.T) {
testCases := map[string]struct {
issuer *Issuer
userData []byte
nonce []byte
}{
"fail openTPM": {
issuer: NewIssuer(
func() (io.ReadWriteCloser, error) {
return nil, errors.New("failure")
},
tpmclient.AttestationKeyRSA,
fakeGetInstanceInfo,
),
userData: []byte("Constellation"),
nonce: []byte{1, 2, 3, 4},
},
"fail getAttestationKey": {
issuer: NewIssuer(
newSimTPMWithEventLog,
func(tpm io.ReadWriter) (*tpmclient.Key, error) {
return nil, errors.New("failure")
},
fakeGetInstanceInfo,
),
userData: []byte("Constellation"),
nonce: []byte{1, 2, 3, 4},
},
"fail Attest": {
issuer: NewIssuer(
newSimTPMWithEventLog,
func(tpm io.ReadWriter) (*tpmclient.Key, error) {
return &tpmclient.Key{}, nil
},
fakeGetInstanceInfo,
),
userData: []byte("Constellation"),
nonce: []byte{1, 2, 3, 4},
},
"fail getInstanceInfo": {
issuer: NewIssuer(
newSimTPMWithEventLog,
tpmclient.AttestationKeyRSA,
func(io.ReadWriteCloser) ([]byte, error) { return nil, errors.New("failure") },
),
userData: []byte("Constellation"),
nonce: []byte{1, 2, 3, 4},
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
_, err := tc.issuer.Issue(tc.userData, tc.nonce)
assert.Error(err)
})
}
}
// TestGetSHA256QuoteIndex feeds GetSHA256QuoteIndex several quote lists and
// checks the returned index. For lists where an error is expected it also
// checks that the returned index is 0.
func TestGetSHA256QuoteIndex(t *testing.T) {
	// quoteWith builds a quote whose PCR set only carries the given hash algorithm.
	quoteWith := func(algo tpm.HashAlgo) *tpm.Quote {
		return &tpm.Quote{Pcrs: &tpm.PCRs{Hash: algo}}
	}

	tests := map[string]struct {
		quotes      []*tpm.Quote
		idxExpected int
		errExpected bool
	}{
		"idx 0 is valid": {
			quotes: []*tpm.Quote{quoteWith(tpm.HashAlgo_SHA256)},
		},
		"idx 1 is valid": {
			quotes: []*tpm.Quote{
				quoteWith(tpm.HashAlgo_SHA1),
				quoteWith(tpm.HashAlgo_SHA256),
			},
			idxExpected: 1,
		},
		"no quotes": {
			errExpected: true,
		},
		"quotes is nil": {
			// Two entries, both nil pointers.
			quotes:      make([]*tpm.Quote, 2),
			errExpected: true,
		},
		"pcrs is nil": {
			quotes: []*tpm.Quote{
				quoteWith(tpm.HashAlgo_SHA1),
				{Pcrs: nil},
			},
			errExpected: true,
		},
	}

	for label, tt := range tests {
		t.Run(label, func(t *testing.T) {
			gotIdx, gotErr := GetSHA256QuoteIndex(tt.quotes)

			if tt.errExpected {
				assert.Error(t, gotErr)
				assert.Equal(t, 0, gotIdx)
				return
			}

			assert.NoError(t, gotErr)
			assert.Equal(t, tt.idxExpected, gotIdx)
		})
	}
}
// TestGetSelectedPCRs calls GetSelectedPCRs with several TPM open functions
// and PCR selections. When the open function fails an error is expected;
// otherwise the result must contain one entry per selected PCR and every
// entry must be 32 bytes long.
func TestGetSelectedPCRs(t *testing.T) {
	testCases := map[string]struct {
		openFunc     TPMOpenFunc
		pcrSelection tpm2.PCRSelection
		errExpected  bool
	}{
		"error": {
			openFunc: func() (io.ReadWriteCloser, error) { return nil, errors.New("error") },
			pcrSelection: tpm2.PCRSelection{
				Hash: tpm2.AlgSHA256,
				PCRs: []int{0, 1, 2},
			},
			errExpected: true,
		},
		"3 PCRs": {
			openFunc: OpenSimulatedTPM,
			pcrSelection: tpm2.PCRSelection{
				Hash: tpm2.AlgSHA256,
				PCRs: []int{0, 1, 2},
			},
		},
		"Azure PCRS": {
			openFunc:     OpenSimulatedTPM,
			pcrSelection: AzurePCRSelection,
		},
		"GCP PCRs": {
			openFunc:     OpenSimulatedTPM,
			pcrSelection: GCPPCRSelection,
		},
	}

	for name, tc := range testCases {
		t.Run(name, func(t *testing.T) {
			require := require.New(t)
			assert := assert.New(t)

			pcrs, err := GetSelectedPCRs(tc.openFunc, tc.pcrSelection)
			if tc.errExpected {
				assert.Error(err)
				return
			}
			require.NoError(err)

			// assert.Len reports the container and the wanted length in the
			// right roles; the previous assert.Equal call passed the actual
			// value in the "expected" position.
			assert.Len(pcrs, len(tc.pcrSelection.PCRs))
			for _, pcr := range pcrs {
				assert.Len(pcr, 32)
			}
		})
	}
}