constellation/internal/attestation/azure/azure_test.go

166 lines
4.6 KiB
Go
Raw Normal View History

/*
Copyright (c) Edgeless Systems GmbH
SPDX-License-Identifier: AGPL-3.0-only
*/
package azure
import (
"bytes"
"crypto/sha256"
"encoding/base64"
"encoding/binary"
"encoding/json"
"os"
"testing"
"github.com/edgelesssys/constellation/v2/internal/attestation/simulator"
"github.com/edgelesssys/constellation/v2/internal/attestation/snp"
"github.com/google/go-tpm-tools/client"
tpmclient "github.com/google/go-tpm-tools/client"
"github.com/google/go-tpm/legacy/tpm2"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestValidateAk tests the attestation key validation with a simulated TPM device.
func TestValidateAk(t *testing.T) {
cgo := os.Getenv("CGO_ENABLED")
if cgo == "0" {
t.Skip("skipping test because CGO is disabled and tpm simulator requires it")
}
int32ToBytes := func(val uint32) []byte {
r := make([]byte, 4)
binary.PutUvarint(r, uint64(val))
return r
}
require := require.New(t)
tpm, err := simulator.OpenSimulatedTPM()
require.NoError(err)
defer tpm.Close()
key, err := client.AttestationKeyRSA(tpm)
require.NoError(err)
defer key.Close()
e := base64.RawURLEncoding.EncodeToString(int32ToBytes(key.PublicArea().RSAParameters.ExponentRaw))
n := base64.RawURLEncoding.EncodeToString(key.PublicArea().RSAParameters.ModulusRaw)
ak := akPub{E: e, N: n}
rtData := runtimeData{PublicPart: []akPub{ak}}
defaultRuntimeDataRaw, err := json.Marshal(rtData)
require.NoError(err)
defaultInstanceInfo := snp.InstanceInfo{Azure: &snp.AzureInstanceInfo{RuntimeData: defaultRuntimeDataRaw}}
sig := sha256.Sum256(defaultRuntimeDataRaw)
defaultReportData := sig[:]
defaultRsaParams := key.PublicArea().RSAParameters
testCases := map[string]struct {
instanceInfo snp.InstanceInfo
runtimeDataRaw []byte
reportData []byte
rsaParameters *tpm2.RSAParams
wantErr bool
}{
"success": {
instanceInfo: defaultInstanceInfo,
runtimeDataRaw: defaultRuntimeDataRaw,
reportData: defaultReportData,
rsaParameters: defaultRsaParams,
},
"invalid json": {
instanceInfo: defaultInstanceInfo,
runtimeDataRaw: []byte(""),
reportData: defaultReportData,
rsaParameters: defaultRsaParams,
wantErr: true,
},
"invalid hash": {
instanceInfo: defaultInstanceInfo,
runtimeDataRaw: defaultRuntimeDataRaw,
reportData: bytes.Repeat([]byte{0}, 64),
rsaParameters: defaultRsaParams,
wantErr: true,
},
"invalid E": {
instanceInfo: defaultInstanceInfo,
runtimeDataRaw: defaultRuntimeDataRaw,
reportData: defaultReportData,
rsaParameters: func() *tpm2.RSAParams {
tmp := *defaultRsaParams
tmp.ExponentRaw = 1
return &tmp
}(),
wantErr: true,
},
"invalid N": {
instanceInfo: defaultInstanceInfo,
runtimeDataRaw: defaultRuntimeDataRaw,
reportData: defaultReportData,
rsaParameters: func() *tpm2.RSAParams {
tmp := *defaultRsaParams
tmp.ModulusRaw = []byte{0, 1, 2, 3}
return &tmp
}(),
wantErr: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
ak := HCLAkValidator{}
err = ak.Validate(tc.runtimeDataRaw, tc.reportData, tc.rsaParameters)
if tc.wantErr {
assert.Error(err)
} else {
assert.NoError(err)
}
})
}
}
// TestGetHCLAttestationKey is a basic smoke test that only checks if GetAttestationKey can be run error free.
// Testing anything else will only verify that the simulator works as expected, since GetAttestationKey
// only retrieves the attestation key from the TPM.
func TestGetHCLAttestationKey(t *testing.T) {
cgo := os.Getenv("CGO_ENABLED")
if cgo == "0" {
t.Skip("skipping test because CGO is disabled and tpm simulator requires it")
}
require := require.New(t)
assert := assert.New(t)
tpm, err := simulator.OpenSimulatedTPM()
require.NoError(err)
defer tpm.Close()
// we should receive an error if no key was saved at index `tpmAkIdx`
_, err = GetAttestationKey(tpm)
assert.Error(err)
// create a key at the index
tpmAk, err := tpmclient.NewCachedKey(tpm, tpm2.HandleOwner, tpm2.Public{
Type: tpm2.AlgRSA,
NameAlg: tpm2.AlgSHA256,
Attributes: tpm2.FlagFixedTPM | tpm2.FlagFixedParent | tpm2.FlagSensitiveDataOrigin | tpm2.FlagUserWithAuth | tpm2.FlagNoDA | tpm2.FlagRestricted | tpm2.FlagSign,
RSAParameters: &tpm2.RSAParams{
Sign: &tpm2.SigScheme{
Alg: tpm2.AlgRSASSA,
Hash: tpm2.AlgSHA256,
},
KeyBits: 2048,
},
}, tpmAkIdx)
require.NoError(err)
defer tpmAk.Close()
// we should now be able to retrieve the key
_, err = GetAttestationKey(tpm)
assert.NoError(err)
}