constellation/internal/attestation/azure/snp/validator_test.go
Daniel Weiße e350ca0f57
attestation: add Azure TDX attestation (#2827)
* Implement Azure TDX attestation primitives
* Add default measurements and claims for Azure TDX
* Enable Constellation on Azure TDX

---------

Signed-off-by: Daniel Weiße <dw@edgeless.systems>
2024-01-24 15:10:15 +01:00

737 lines
23 KiB
Go

/*
Copyright (c) Edgeless Systems GmbH
SPDX-License-Identifier: AGPL-3.0-only
*/
package snp
import (
"bytes"
"context"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"os"
"regexp"
"testing"
"github.com/edgelesssys/constellation/v2/internal/attestation"
"github.com/edgelesssys/constellation/v2/internal/attestation/idkeydigest"
"github.com/edgelesssys/constellation/v2/internal/attestation/simulator"
"github.com/edgelesssys/constellation/v2/internal/attestation/snp/testdata"
"github.com/edgelesssys/constellation/v2/internal/attestation/vtpm"
"github.com/edgelesssys/constellation/v2/internal/config"
"github.com/edgelesssys/constellation/v2/internal/logger"
"github.com/google/go-sev-guest/abi"
"github.com/google/go-sev-guest/kds"
spb "github.com/google/go-sev-guest/proto/sevsnp"
"github.com/google/go-sev-guest/validate"
"github.com/google/go-sev-guest/verify"
"github.com/google/go-sev-guest/verify/trust"
"github.com/google/go-tpm-tools/client"
"github.com/google/go-tpm-tools/proto/attest"
"github.com/google/go-tpm/legacy/tpm2"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestNewValidator tests the creation of a new validator.
func TestNewValidator(t *testing.T) {
require := require.New(t)
testCases := map[string]struct {
cfg *config.AzureSEVSNP
logger attestation.Logger
}{
"success": {
cfg: config.DefaultForAzureSEVSNP(),
logger: logger.NewTest(t),
},
"nil logger": {
cfg: config.DefaultForAzureSEVSNP(),
logger: nil,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
validator := NewValidator(tc.cfg, tc.logger)
require.NotNil(validator)
require.NotNil(validator.log)
})
}
}
// stubHTTPSGetter is a test double for the getter field of the Validator.
// It either fails every request with a fixed error or answers requests
// through a urlResponseMatcher.
type stubHTTPSGetter struct {
	urlResponseMatcher *urlResponseMatcher // maps responses to requested URLs
	err                error               // if non-nil, returned by every Get call
}

// newStubHTTPSGetter builds a stubHTTPSGetter from the given matcher and error.
func newStubHTTPSGetter(urlResponseMatcher *urlResponseMatcher, err error) *stubHTTPSGetter {
	getter := new(stubHTTPSGetter)
	getter.urlResponseMatcher = urlResponseMatcher
	getter.err = err
	return getter
}

// Get returns the configured error when one is set. Otherwise the request is
// handed to the URL matcher, which decides on the response.
func (s *stubHTTPSGetter) Get(url string) ([]byte, error) {
	if s.err == nil {
		return s.urlResponseMatcher.match(url)
	}
	return nil, s.err
}
// stubVCEKURLPattern matches VCEK request URLs of the AMD KDS for Milan.
// It is compiled once at package scope, anchored at the start of the URL, and
// treats the dots in the host name literally.
var stubVCEKURLPattern = regexp.MustCompile(`^https://kdsintf\.amd\.com/vcek/v1/Milan/`)

// urlResponseMatcher maps requested AMD KDS URLs to canned responses and
// rejects requests a test case did not announce.
type urlResponseMatcher struct {
	certChainResponse    []byte // returned for the cert_chain URL
	wantCertChainRequest bool   // whether a cert_chain request is expected
	vcekResponse         []byte // returned for any VCEK URL
	wantVcekRequest      bool   // whether a VCEK request is expected
}

// match returns the canned response for url.
//
// The exact cert_chain URL is checked first because it shares its prefix with
// the VCEK URLs. An error is returned if the request kind was not expected by
// the test case or if the URL is not a known KDS URL.
func (m *urlResponseMatcher) match(url string) ([]byte, error) {
	switch {
	case url == "https://kdsintf.amd.com/vcek/v1/Milan/cert_chain":
		if !m.wantCertChainRequest {
			return nil, errors.New("unexpected cert_chain request")
		}
		return m.certChainResponse, nil
	case stubVCEKURLPattern.MatchString(url):
		if !m.wantVcekRequest {
			return nil, errors.New("unexpected VCEK request")
		}
		return m.vcekResponse, nil
	default:
		return nil, fmt.Errorf("unexpected URL: %s", url)
	}
}
// TestCheckIDKeyDigest tests validation of an IDKeyDigest under different enforcement policies.
//
// The MAA validator is replaced by a stub whose result is controlled per case.
func TestCheckIDKeyDigest(t *testing.T) {
	// buildConfig returns the default Azure SEV-SNP config with the given
	// enforcement policy and accepted digests applied.
	buildConfig := func(enforcementPolicy idkeydigest.Enforcement, digestStrings []string) *config.AzureSEVSNP {
		accepted := idkeydigest.List{}
		for i := range digestStrings {
			accepted = append(accepted, []byte(digestStrings[i]))
		}

		cfg := config.DefaultForAzureSEVSNP()
		cfg.FirmwareSignerConfig.EnforcementPolicy = enforcementPolicy
		cfg.FirmwareSignerConfig.AcceptedKeyDigests = accepted
		return cfg
	}

	// buildReport returns an attestation whose report carries only the given ID key digest.
	buildReport := func(idKeyDigest string) *spb.Attestation {
		return &spb.Attestation{
			Report: &spb.Report{
				IdKeyDigest: []byte(idKeyDigest),
			},
		}
	}

	// buildValidator returns a validator whose MAA validation yields validateTokenErr.
	buildValidator := func(cfg *config.AzureSEVSNP, validateTokenErr error) *Validator {
		v := NewValidator(cfg, logger.NewTest(t))
		v.maa = &stubMaaValidator{validateTokenErr: validateTokenErr}
		return v
	}

	testCases := map[string]struct {
		idKeyDigest          string
		acceptedIDKeyDigests []string
		enforcementPolicy    idkeydigest.Enforcement
		validateMaaTokenErr  error
		wantErr              bool
	}{
		"matching digest": {
			idKeyDigest:          "test",
			acceptedIDKeyDigests: []string{"test"},
		},
		"no accepted digests": {
			idKeyDigest:          "test",
			acceptedIDKeyDigests: []string{},
			wantErr:              true,
		},
		"mismatching digest, enforce": {
			idKeyDigest:          "test",
			acceptedIDKeyDigests: []string{"other"},
			wantErr:              true,
		},
		"mismatching digest, maaFallback": {
			idKeyDigest:          "test",
			acceptedIDKeyDigests: []string{"other"},
			enforcementPolicy:    idkeydigest.MAAFallback,
		},
		"mismatching digest, maaFallback errors": {
			idKeyDigest:          "test",
			acceptedIDKeyDigests: []string{"other"},
			enforcementPolicy:    idkeydigest.MAAFallback,
			validateMaaTokenErr:  errors.New("maa fallback failed"),
			wantErr:              true,
		},
		"mismatching digest, warnOnly": {
			idKeyDigest:          "test",
			acceptedIDKeyDigests: []string{"other"},
			enforcementPolicy:    idkeydigest.WarnOnly,
		},
	}

	for name, tc := range testCases {
		t.Run(name, func(t *testing.T) {
			require := require.New(t)

			validator := buildValidator(
				buildConfig(tc.enforcementPolicy, tc.acceptedIDKeyDigests),
				tc.validateMaaTokenErr,
			)

			err := validator.checkIDKeyDigest(context.Background(), buildReport(tc.idKeyDigest), "", nil)
			if !tc.wantErr {
				require.NoError(err)
				return
			}
			require.Error(err)
		})
	}
}
// stubMaaValidator is a test double for the validator's maa field.
type stubMaaValidator struct {
	validateTokenErr error // returned unchanged by validateToken
}

// validateToken ignores all of its arguments and reports the preconfigured error.
func (s *stubMaaValidator) validateToken(context.Context, string, string, []byte) error {
	return s.validateTokenErr
}
// TestTrustedKeyFromSNP tests the verification and validation of attestation reports
// performed by Validator.getTrustedKey.
//
// Each case assembles a stub instance info (report, runtime data, VCEK, cert chain),
// wraps it in an attestation document and expects getTrustedKey to either return a
// non-nil key without error, or to fail (optionally with a specific error message).
func TestTrustedKeyFromSNP(t *testing.T) {
// The simulated TPM requires cgo; skip when it is explicitly disabled.
cgo := os.Getenv("CGO_ENABLED")
if cgo == "0" {
t.Skip("skipping test because CGO is disabled and tpm simulator requires it")
}
tpm, err := simulator.OpenSimulatedTPM()
require.NoError(t, err)
defer tpm.Close()
key, err := client.AttestationKeyRSA(tpm)
require.NoError(t, err)
defer key.Close()
akPub, err := key.PublicArea().Encode()
require.NoError(t, err)
// defaultCfg is a single config instance shared by all test cases.
defaultCfg := config.DefaultForAzureSEVSNP()
defaultReport := hex.EncodeToString(testdata.AttestationReport)
defaultRuntimeData := hex.EncodeToString(testdata.RuntimeData)
defaultIDKeyDigestOld, err := hex.DecodeString("57e229e0ffe5fa92d0faddff6cae0e61c926fc9ef9afd20a8b8cfcf7129db9338cbe5bf3f6987733a2bf65d06dc38fc1")
require.NoError(t, err)
defaultIDKeyDigest := idkeydigest.NewList([][]byte{defaultIDKeyDigestOld})
// defaultVerifier runs the real verification; skipVerifier bypasses it so that
// cases with a modified (and thus no longer correctly signed) report reach validation.
defaultVerifier := &stubAttestationVerifier{}
skipVerifier := &stubAttestationVerifier{skipCheck: true}
defaultValidator := &stubAttestationValidator{}
// reportTransformer unpacks the hex-encoded report, applies the given transformations and re-encodes it.
// The transformations run immediately, i.e. while the test table below is being built.
reportTransformer := func(reportHex string, transformations func(*spb.Report)) string {
rawReport, err := hex.DecodeString(reportHex)
require.NoError(t, err)
report, err := abi.ReportToProto(rawReport)
require.NoError(t, err)
transformations(report)
reportBytes, err := abi.ReportToAbiBytes(report)
require.NoError(t, err)
return hex.EncodeToString(reportBytes)
}
testCases := map[string]struct {
report string
runtimeData string
vcek []byte
certChain []byte
acceptedIDKeyDigests idkeydigest.List
enforcementPolicy idkeydigest.Enforcement
getter *stubHTTPSGetter
verifier *stubAttestationVerifier
validator *stubAttestationValidator
wantErr bool
assertion func(*assert.Assertions, error)
}{
"success": {
report: defaultReport,
runtimeData: defaultRuntimeData,
acceptedIDKeyDigests: defaultIDKeyDigest,
enforcementPolicy: idkeydigest.Equal,
verifier: defaultVerifier,
validator: defaultValidator,
vcek: testdata.AzureThimVCEK,
certChain: testdata.CertChain,
getter: newStubHTTPSGetter(
&urlResponseMatcher{},
nil,
),
},
"certificate fetch error": {
report: defaultReport,
runtimeData: defaultRuntimeData,
acceptedIDKeyDigests: defaultIDKeyDigest,
enforcementPolicy: idkeydigest.Equal,
verifier: defaultVerifier,
validator: defaultValidator,
getter: newStubHTTPSGetter(
nil,
assert.AnError,
),
wantErr: true,
assertion: func(assert *assert.Assertions, err error) {
assert.ErrorContains(err, "retrieving VCEK certificate from AMD KDS")
},
},
"fetch vcek and certchain": {
report: defaultReport,
runtimeData: defaultRuntimeData,
acceptedIDKeyDigests: defaultIDKeyDigest,
enforcementPolicy: idkeydigest.Equal,
verifier: defaultVerifier,
validator: defaultValidator,
getter: newStubHTTPSGetter(
&urlResponseMatcher{
vcekResponse: testdata.AmdKdsVCEK,
wantVcekRequest: true,
certChainResponse: testdata.CertChain,
wantCertChainRequest: true,
},
nil,
),
},
"fetch vcek": {
report: defaultReport,
runtimeData: defaultRuntimeData,
acceptedIDKeyDigests: defaultIDKeyDigest,
enforcementPolicy: idkeydigest.Equal,
verifier: defaultVerifier,
validator: defaultValidator,
certChain: testdata.CertChain,
getter: newStubHTTPSGetter(
&urlResponseMatcher{
vcekResponse: testdata.AmdKdsVCEK,
wantVcekRequest: true,
},
nil,
),
},
"fetch certchain": {
report: defaultReport,
runtimeData: defaultRuntimeData,
acceptedIDKeyDigests: defaultIDKeyDigest,
enforcementPolicy: idkeydigest.Equal,
verifier: defaultVerifier,
validator: defaultValidator,
vcek: testdata.AzureThimVCEK,
getter: newStubHTTPSGetter(
&urlResponseMatcher{
certChainResponse: testdata.CertChain,
wantCertChainRequest: true,
},
nil,
),
},
"invalid report signature": {
report: reportTransformer(defaultReport, func(r *spb.Report) {
r.Signature = make([]byte, 512)
}),
runtimeData: defaultRuntimeData,
acceptedIDKeyDigests: defaultIDKeyDigest,
enforcementPolicy: idkeydigest.Equal,
verifier: defaultVerifier,
validator: defaultValidator,
vcek: testdata.AzureThimVCEK,
certChain: testdata.CertChain,
getter: newStubHTTPSGetter(
&urlResponseMatcher{},
nil,
),
wantErr: true,
assertion: func(assert *assert.Assertions, err error) {
assert.ErrorContains(err, "report signature verification error")
},
},
"invalid vcek": {
report: defaultReport,
runtimeData: defaultRuntimeData,
acceptedIDKeyDigests: defaultIDKeyDigest,
enforcementPolicy: idkeydigest.Equal,
verifier: defaultVerifier,
validator: defaultValidator,
vcek: []byte("invalid"),
certChain: testdata.CertChain,
getter: newStubHTTPSGetter(
&urlResponseMatcher{
vcekResponse: []byte("invalid"),
wantVcekRequest: true,
},
nil,
),
wantErr: true,
assertion: func(assert *assert.Assertions, err error) {
assert.ErrorContains(err, "could not interpret VCEK DER bytes: x509: malformed certificate")
},
},
// NOTE(review): the case name suggests a successful fallback, yet an error is
// expected and no error message is asserted — confirm which outcome is intended.
"invalid certchain fall back to embedded": {
report: defaultReport,
runtimeData: defaultRuntimeData,
acceptedIDKeyDigests: defaultIDKeyDigest,
enforcementPolicy: idkeydigest.Equal,
verifier: defaultVerifier,
validator: defaultValidator,
vcek: testdata.AzureThimVCEK,
certChain: []byte("invalid"),
getter: newStubHTTPSGetter(
&urlResponseMatcher{
certChainResponse: []byte("invalid"),
wantCertChainRequest: true,
},
nil,
),
wantErr: true,
},
"invalid runtime data": {
report: defaultReport,
// Dropping the last hex characters truncates the decoded runtime data.
runtimeData: defaultRuntimeData[:len(defaultRuntimeData)-10],
acceptedIDKeyDigests: defaultIDKeyDigest,
enforcementPolicy: idkeydigest.Equal,
verifier: defaultVerifier,
validator: defaultValidator,
vcek: testdata.AzureThimVCEK,
certChain: testdata.CertChain,
getter: newStubHTTPSGetter(
&urlResponseMatcher{},
nil,
),
wantErr: true,
assertion: func(assert *assert.Assertions, err error) {
assert.ErrorContains(err, "validating HCLAkPub: unmarshalling json: unexpected end of JSON input")
},
},
"inacceptable idkeydigest (wrong size), enforce": {
report: defaultReport,
runtimeData: defaultRuntimeData,
acceptedIDKeyDigests: idkeydigest.List{[]byte{0x00}},
enforcementPolicy: idkeydigest.Equal,
verifier: defaultVerifier,
validator: defaultValidator,
vcek: testdata.AzureThimVCEK,
certChain: testdata.CertChain,
getter: newStubHTTPSGetter(
&urlResponseMatcher{},
nil,
),
wantErr: true,
assertion: func(assert *assert.Assertions, err error) {
assert.ErrorContains(err, "bad hash size in TrustedIDKeyHashes")
},
},
"inacceptable idkeydigest (wrong value), enforce": {
report: defaultReport,
runtimeData: defaultRuntimeData,
acceptedIDKeyDigests: idkeydigest.List{make([]byte, 48)},
enforcementPolicy: idkeydigest.Equal,
verifier: defaultVerifier,
validator: defaultValidator,
vcek: testdata.AzureThimVCEK,
certChain: testdata.CertChain,
getter: newStubHTTPSGetter(
&urlResponseMatcher{},
nil,
),
wantErr: true,
assertion: func(assert *assert.Assertions, err error) {
assert.ErrorContains(err, "report ID key not trusted")
},
},
"inacceptable idkeydigest, warn only": {
report: defaultReport,
runtimeData: defaultRuntimeData,
acceptedIDKeyDigests: idkeydigest.List{[]byte{0x00}},
enforcementPolicy: idkeydigest.WarnOnly,
verifier: defaultVerifier,
validator: defaultValidator,
vcek: testdata.AzureThimVCEK,
certChain: testdata.CertChain,
getter: newStubHTTPSGetter(
&urlResponseMatcher{},
nil,
),
},
"launch tcb < minimum launch tcb": {
report: reportTransformer(defaultReport, func(r *spb.Report) {
launchTcb := kds.DecomposeTCBVersion(kds.TCBVersion(r.LaunchTcb))
// NOTE(review): this writes to the shared defaultCfg while the table is
// being built, so the new microcode minimum applies to every test case,
// not only this one — confirm this is intended.
defaultCfg.MicrocodeVersion.Value = 10
launchTcb.UcodeSpl = 9
newLaunchTcb, err := kds.ComposeTCBParts(launchTcb)
require.NoError(t, err)
r.LaunchTcb = uint64(newLaunchTcb)
}),
runtimeData: defaultRuntimeData,
acceptedIDKeyDigests: defaultIDKeyDigest,
enforcementPolicy: idkeydigest.WarnOnly,
verifier: skipVerifier,
validator: defaultValidator,
vcek: testdata.AzureThimVCEK,
certChain: testdata.CertChain,
getter: newStubHTTPSGetter(
&urlResponseMatcher{},
nil,
),
wantErr: true,
assertion: func(assert *assert.Assertions, err error) {
assert.ErrorContains(err, "is lower than the policy minimum launch TCB")
},
},
"reported tcb < minimum tcb": {
report: reportTransformer(defaultReport, func(r *spb.Report) {
reportedTcb := kds.DecomposeTCBVersion(kds.TCBVersion(r.ReportedTcb))
// NOTE(review): presumably relies on MicrocodeVersion.Value having been set
// by the "launch tcb < minimum launch tcb" case above — verify.
reportedTcb.UcodeSpl = defaultCfg.MicrocodeVersion.Value - 1
newReportedTcb, err := kds.ComposeTCBParts(reportedTcb)
require.NoError(t, err)
r.ReportedTcb = uint64(newReportedTcb)
}),
runtimeData: defaultRuntimeData,
acceptedIDKeyDigests: defaultIDKeyDigest,
enforcementPolicy: idkeydigest.WarnOnly,
verifier: skipVerifier,
validator: defaultValidator,
vcek: testdata.AzureThimVCEK,
certChain: testdata.CertChain,
getter: newStubHTTPSGetter(
&urlResponseMatcher{},
nil,
),
wantErr: true,
assertion: func(assert *assert.Assertions, err error) {
assert.ErrorContains(err, "is lower than the policy minimum TCB")
},
},
"current tcb < committed tcb": {
report: reportTransformer(defaultReport, func(r *spb.Report) {
r.CurrentTcb = r.CommittedTcb - 1
}),
runtimeData: defaultRuntimeData,
acceptedIDKeyDigests: defaultIDKeyDigest,
enforcementPolicy: idkeydigest.WarnOnly,
verifier: skipVerifier,
validator: defaultValidator,
vcek: testdata.AzureThimVCEK,
certChain: testdata.CertChain,
getter: newStubHTTPSGetter(
&urlResponseMatcher{},
nil,
),
wantErr: true,
assertion: func(assert *assert.Assertions, err error) {
assert.ErrorContains(err, "is lower than the report's COMMITTED_TCB")
},
},
"current tcb < tcb in vcek": {
report: reportTransformer(defaultReport, func(r *spb.Report) {
currentTcb := kds.DecomposeTCBVersion(kds.TCBVersion(r.CurrentTcb))
currentTcb.UcodeSpl = 0x5c // testdata.AzureThimVCEK has ucode version 0x5d
newCurrentTcb, err := kds.ComposeTCBParts(currentTcb)
require.NoError(t, err)
r.CurrentTcb = uint64(newCurrentTcb)
}),
runtimeData: defaultRuntimeData,
acceptedIDKeyDigests: defaultIDKeyDigest,
enforcementPolicy: idkeydigest.WarnOnly,
verifier: skipVerifier,
validator: defaultValidator,
vcek: testdata.AzureThimVCEK,
certChain: testdata.CertChain,
getter: newStubHTTPSGetter(
&urlResponseMatcher{},
nil,
),
wantErr: true,
assertion: func(assert *assert.Assertions, err error) {
assert.ErrorContains(err, "is lower than the TCB of the V[CL]EK certificate")
},
},
"reported tcb != tcb in vcek": {
report: reportTransformer(defaultReport, func(r *spb.Report) {
r.ReportedTcb = uint64(0)
}),
runtimeData: defaultRuntimeData,
acceptedIDKeyDigests: defaultIDKeyDigest,
enforcementPolicy: idkeydigest.WarnOnly,
verifier: skipVerifier,
validator: defaultValidator,
vcek: testdata.AzureThimVCEK,
certChain: testdata.CertChain,
getter: newStubHTTPSGetter(
&urlResponseMatcher{},
nil,
),
wantErr: true,
assertion: func(assert *assert.Assertions, err error) {
assert.ErrorContains(err, "does not match the TCB of the V[CL]EK certificate")
},
},
"vmpl != 0": {
report: reportTransformer(defaultReport, func(r *spb.Report) {
r.Vmpl = 1
}),
runtimeData: defaultRuntimeData,
acceptedIDKeyDigests: defaultIDKeyDigest,
enforcementPolicy: idkeydigest.Equal,
verifier: skipVerifier,
validator: defaultValidator,
vcek: testdata.AzureThimVCEK,
certChain: testdata.CertChain,
getter: newStubHTTPSGetter(
&urlResponseMatcher{},
nil,
),
wantErr: true,
assertion: func(assert *assert.Assertions, err error) {
assert.ErrorContains(err, "report VMPL 1 is not 0")
},
},
"invalid ASK": {
report: defaultReport,
runtimeData: defaultRuntimeData,
acceptedIDKeyDigests: defaultIDKeyDigest,
enforcementPolicy: idkeydigest.Equal,
verifier: defaultVerifier,
validator: defaultValidator,
vcek: testdata.AzureThimVCEK,
// Corrupt a copy of the chain so the shared testdata.CertChain stays intact.
certChain: func() []byte {
c := make([]byte, len(testdata.CertChain))
copy(c, testdata.CertChain)
c[1676] = 0x41 // somewhere in the ASK signature
return c
}(),
getter: newStubHTTPSGetter(
&urlResponseMatcher{},
nil,
),
wantErr: true,
assertion: func(assert *assert.Assertions, err error) {
assert.ErrorContains(err, "crypto/rsa: verification error")
},
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
// This is important. Without this call, the trust module caches certificates across testcases.
defer trust.ClearProductCertCache()
instanceInfo, err := newStubInstanceInfo(tc.vcek, tc.certChain, tc.report, tc.runtimeData)
require.NoError(err)
statement, err := json.Marshal(instanceInfo)
require.NoError(err)
attDoc := vtpm.AttestationDocument{
InstanceInfo: statement,
Attestation: &attest.Attestation{
AkPub: akPub,
},
}
// Apply the per-case firmware signer settings to the shared config.
defaultCfg.FirmwareSignerConfig = config.SNPFirmwareSignerConfig{
AcceptedKeyDigests: tc.acceptedIDKeyDigests,
EnforcementPolicy: tc.enforcementPolicy,
}
validator := &Validator{
hclValidator: &stubAttestationKey{},
config: defaultCfg,
log: logger.NewTest(t),
getter: tc.getter,
attestationVerifier: tc.verifier,
attestationValidator: tc.validator,
}
key, err := validator.getTrustedKey(context.Background(), attDoc, nil)
if tc.wantErr {
assert.Error(err)
// The optional assertion checks the error message of the failing case.
if tc.assertion != nil {
tc.assertion(assert, err)
}
} else {
assert.NoError(err)
assert.NotNil(key)
}
})
}
}
// stubAttestationVerifier forwards to verify.SnpAttestation unless it is
// configured to bypass verification.
type stubAttestationVerifier struct {
	skipCheck bool // when true, SNPAttestation returns nil without verifying anything
}

// SNPAttestation delegates verification of the attestation to
// verify.SnpAttestation, or reports success right away if skipCheck is set.
func (v *stubAttestationVerifier) SNPAttestation(attestation *spb.Attestation, options *verify.Options) error {
	if !v.skipCheck {
		return verify.SnpAttestation(attestation, options)
	}
	return nil
}
// stubAttestationValidator forwards to validate.SnpAttestation unless it is
// configured to bypass validation.
type stubAttestationValidator struct {
	skipCheck bool // when true, SNPAttestation returns nil without validating anything
}

// SNPAttestation delegates validation of the attestation against the given
// options to validate.SnpAttestation, or reports success right away if
// skipCheck is set.
func (v *stubAttestationValidator) SNPAttestation(attestation *spb.Attestation, options *validate.Options) error {
	if !v.skipCheck {
		return validate.SnpAttestation(attestation, options)
	}
	return nil
}
// stubInstanceInfo is the test representation of the instance info that gets
// JSON-marshalled into the attestation document.
type stubInstanceInfo struct {
	AttestationReport []byte
	ReportSigner      []byte
	CertChain         []byte
	Azure             *stubAzureInstanceInfo
}

// stubAzureInstanceInfo holds the Azure-specific part of the stub instance info.
type stubAzureInstanceInfo struct {
	RuntimeData []byte
}

// newStubInstanceInfo assembles a stubInstanceInfo.
//
// vcek and certChain are stored as-is in ReportSigner and CertChain. report and
// runtimeData are hex-encoded strings that are decoded before being stored.
// If either string is not valid hex, the zero value and an error wrapping the
// decoding error are returned.
func newStubInstanceInfo(vcek, certChain []byte, report, runtimeData string) (stubInstanceInfo, error) {
	decodedReport, err := hex.DecodeString(report)
	if err != nil {
		// Wrap with %w so callers can inspect the underlying hex error.
		return stubInstanceInfo{}, fmt.Errorf("invalid hex string report: %w", err)
	}

	decodedRuntime, err := hex.DecodeString(runtimeData)
	if err != nil {
		return stubInstanceInfo{}, fmt.Errorf("invalid hex string runtimeData: %w", err)
	}

	return stubInstanceInfo{
		AttestationReport: decodedReport,
		ReportSigner:      vcek,
		CertChain:         certChain,
		Azure: &stubAzureInstanceInfo{
			RuntimeData: decodedRuntime,
		},
	}, nil
}
// stubAttestationKey is a test double for the validator's hclValidator field.
type stubAttestationKey struct{}

// Validate checks that runtimeDataRaw decodes as JSON and that its SHA-256
// digest equals the first sha256.Size bytes of reportData. The RSA parameters
// are ignored.
//
// It returns an error if the JSON cannot be decoded, if reportData is too
// short to contain a digest, or if the digests differ.
func (s *stubAttestationKey) Validate(runtimeDataRaw []byte, reportData []byte, _ *tpm2.RSAParams) error {
	// The stub has no fields, so decoding into it only checks that the input
	// is JSON that can be decoded into a struct; nothing is retained.
	if err := json.Unmarshal(runtimeDataRaw, s); err != nil {
		return fmt.Errorf("unmarshalling json: %w", err)
	}

	// Guard the slice expression below: a short reportData would otherwise panic.
	if len(reportData) < sha256.Size {
		return fmt.Errorf("reportData too short: got %d bytes, want at least %d", len(reportData), sha256.Size)
	}

	sum := sha256.Sum256(runtimeDataRaw)
	if !bytes.Equal(sum[:], reportData[:sha256.Size]) {
		return errors.New("unexpected runtimeData digest in TPM")
	}
	return nil
}